[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-04 13:38:47 +00:00 committed by AdilZouitine
parent cc9a37f3f8
commit 45a03d253a
95 changed files with 3163 additions and 972 deletions

View File

@ -32,7 +32,11 @@ import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
structural_similarity,
)
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int:
return total_size
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
def load_original_frames(
imgs_dir: Path, timestamps: list[float], fps: int
) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
imgs_dir: Path,
save_dir: Path,
frames: torch.Tensor,
timestamps: list[float],
fps: int,
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
@ -104,7 +114,10 @@ def save_decoded_frames(
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
shutil.copyfile(
imgs_dir / f"frame_{idx:06d}.png",
save_dir / f"frame_{idx:06d}_original.png",
)
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
img_keys = [
key for key in hf_dataset.features if key.startswith("observation.image")
]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
tqdm(
imgs_dataset,
desc=f"saving {dataset.repo_id} first episode images",
leave=False,
)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
break
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
def sample_timestamps(
timestamps_mode: str, ep_num_images: int, fps: int
) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
@ -154,7 +175,9 @@ def decode_video_frames(
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
else:
raise NotImplementedError(backend)
@ -181,7 +204,9 @@ def benchmark_decoding(
}
with time_benchmark:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
frames = decode_video_frames(
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
)
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
with time_benchmark:
@ -190,12 +215,18 @@ def benchmark_decoding(
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["mse_values"].append(
mean_squared_error(original_frames_np[i], frames_np[i])
)
result["psnr_values"].append(
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
peak_signal_noise_ratio(
original_frames_np[i], frames_np[i], data_range=1.0
)
)
result["ssim_values"].append(
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
structural_similarity(
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
)
)
if save_frames and sample == 0:
@ -215,7 +246,9 @@ def benchmark_decoding(
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
for future in tqdm(
as_completed(futures), total=num_samples, desc="samples", leave=False
):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
@ -275,9 +308,13 @@ def benchmark_encoding_decoding(
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
decoding_cfg["timestamps_modes"],
desc="decodings (timestamps_modes)",
leave=False,
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
for backend in tqdm(
decoding_cfg["backends"], desc="decodings (backends)", leave=False
):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
@ -355,14 +392,23 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for key, values in tqdm(
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
args_path = Path(
"_".join(str(value) for value in encoding_cfg.values())
)
video_path = (
output_dir
/ "videos"
/ args_path
/ f"{repo_id.replace('/', '_')}.mp4"
)
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
@ -388,7 +434,9 @@ def main(
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_path = (
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
)
concatenated_df.to_csv(concatenated_path, header=True, index=False)

View File

@ -18,7 +18,10 @@ import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
@ -26,7 +29,10 @@ pprint(lerobot.available_datasets)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
repo_ids = [
info.id
for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
]
pprint(repo_ids)
# Or simply explore them in your web browser directly at:
@ -41,7 +47,9 @@ ds_meta = LeRobotDatasetMetadata(repo_id)
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(
f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
)
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")

View File

@ -34,10 +34,14 @@ transforms = v2.Compose(
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
transformed_dataset = LeRobotDataset(
dataset_repo_id, episodes=[0], image_transforms=transforms
)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
transformed_frame = transformed_dataset[first_idx][
transformed_dataset.meta.camera_keys[0]
]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")

View File

@ -12,7 +12,10 @@ import math
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

View File

@ -69,7 +69,9 @@ def load_raw_dataset(zarr_path: Path):
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
@ -81,7 +83,9 @@ def calculate_coverage(zarr_data):
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
block_pos = zarr_data["state"][:, 2:4]
@ -111,7 +115,9 @@ def calculate_coverage(zarr_data):
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
block_body, block_shapes = PushTEnv.add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area

View File

@ -164,7 +164,11 @@ available_real_world_datasets = [
]
available_datasets = sorted(
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
set(
itertools.chain(
*available_datasets_per_env.values(), available_real_world_datasets
)
)
)
# lists all available policies from `lerobot/common/policies`
@ -205,9 +209,13 @@ available_policies_per_env = {
"aloha_real": ["act_aloha_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_task_pairs = [
(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
]
env_dataset_pairs = [
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
(env, dataset)
for env, datasets in available_datasets_per_env.items()
for dataset in datasets
]
env_dataset_policy_triplets = [
(env, dataset, policy)

View File

@ -45,12 +45,20 @@ def get_stats_einops_patterns(dataset, num_workers=0):
if key in dataset.meta.camera_keys:
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
assert (
c < h and c < w
), f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
assert (
batch[key].dtype == torch.float32
), f"expect torch.float32, but instead {batch[key].dtype=}"
assert (
batch[key].max() <= 1
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert (
batch[key].min() >= 0
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
@ -98,7 +106,11 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
tqdm.tqdm(
dataloader,
total=ceil(max_num_samples / batch_size),
desc="Compute mean, min, max",
)
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
@ -113,9 +125,16 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
mean[key] = (
mean[key]
+ this_batch_size * (batch_mean - mean[key]) / running_item_count
)
max[key] = torch.maximum(
max[key], einops.reduce(batch[key], pattern, "max")
)
min[key] = torch.minimum(
min[key], einops.reduce(batch[key], pattern, "min")
)
if i == ceil(max_num_samples / batch_size) - 1:
break
@ -124,7 +143,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
tqdm.tqdm(
dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std"
)
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
@ -138,7 +159,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
std[key] = (
std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
)
if i == ceil(max_num_samples / batch_size) - 1:
break
@ -177,13 +200,19 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack(
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
[
ds.meta.stats[data_key][stat_key]
for ds in ls_datasets
if data_key in ds.meta.stats
],
dim=0,
),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
total_samples = sum(
d.num_frames for d in ls_datasets if data_key in d.meta.stats
)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of

View File

@ -109,7 +109,9 @@ class AsyncImageWriter:
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError("Number of threads and processes must be greater than zero.")
raise ValueError(
"Number of threads and processes must be greater than zero."
)
if self.num_processes == 0:
# Use threading
@ -123,12 +125,16 @@ class AsyncImageWriter:
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p = multiprocessing.Process(
target=worker_process, args=(self.queue, self.num_threads)
)
p.daemon = True
p.start()
self.processes.append(p)
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
def save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()

View File

@ -68,7 +68,9 @@ from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.0"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
LEROBOT_HOME = Path(
os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")
).expanduser()
class LeRobotDatasetMetadata:
@ -108,7 +110,11 @@ class LeRobotDatasetMetadata:
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
return (
None
if self.local_files_only
else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
)
@property
def _version(self) -> str:
@ -122,7 +128,9 @@ class LeRobotDatasetMetadata:
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
fpath = self.video_path.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
@ -166,7 +174,11 @@ class LeRobotDatasetMetadata:
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
return [
key
for key, ft in self.features.items()
if ft["dtype"] in ["video", "image"]
]
@property
def names(self) -> dict[str, list | dict]:
@ -215,7 +227,9 @@ class LeRobotDatasetMetadata:
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
def save_episode(
self, episode_index: int, episode_length: int, task: str, task_index: int
) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
@ -257,7 +271,9 @@ class LeRobotDatasetMetadata:
"""
for key in self.video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
video_path = self.root / self.get_video_file_path(
ep_index=0, vid_key=key
)
self.info["features"][key]["info"] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
@ -315,7 +331,9 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, features, use_videos
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@ -451,7 +469,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.local_files_only
)
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
@ -459,10 +479,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Load actual data
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
# Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
check_timestamps_sync(
self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s
)
# Setup delta_indices
if self.delta_timestamps is not None:
@ -508,7 +532,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
create_branch(
repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset"
)
def pull_from_repo(
self,
@ -536,7 +562,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
files = [
str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes
]
if len(self.meta.video_keys) > 0 and download_videos:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
@ -554,7 +582,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
path = str(self.root / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else:
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
files = [
str(self.root / self.meta.get_data_file_path(ep_idx))
for ep_idx in self.episodes
]
hf_dataset = load_dataset("parquet", data_files=files, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
@ -570,12 +601,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
return (
len(self.hf_dataset)
if self.hf_dataset is not None
else self.meta.total_frames
)
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
return (
len(self.episodes)
if self.episodes is not None
else self.meta.total_episodes
)
@property
def features(self) -> dict[str, dict]:
@ -589,16 +628,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
def _get_query_indices(
self, idx: int, ep_idx: int
) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
key: [
max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
for delta in delta_idx
]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
[
(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
for delta in delta_idx
]
)
for key, delta_idx in self.delta_indices.items()
}
@ -626,7 +673,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
def _query_videos(
self, query_timestamps: dict[str, list[float]], ep_idx: int
) -> dict:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@ -656,7 +705,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices = None
if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
current_ep_idx = (
self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
)
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
@ -692,19 +743,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
current_ep_idx = (
self.meta.total_episodes if episode_index is None else episode_index
)
return {
"size": 0,
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
**{
key: current_ep_idx if key == "episode_index" else []
for key in self.features
},
}
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
def _get_image_file_path(
self, episode_index: int, image_key: str, frame_index: int
) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self.root / fpath
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
@ -725,7 +785,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer = self.create_episode_buffer()
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
timestamp = (
frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
)
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
@ -734,11 +796,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
raise ValueError(key)
if self.features[key]["dtype"] not in ["image", "video"]:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
item = (
frame[key].numpy()
if isinstance(frame[key], torch.Tensor)
else frame[key]
)
self.episode_buffer[key].append(item)
elif self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
episode_index=self.episode_buffer["episode_index"],
image_key=key,
frame_index=frame_index,
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
@ -747,7 +815,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["size"] += 1
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(
self, task: str, encode_videos: bool = True, episode_data: dict | None = None
) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@ -814,7 +884,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = datasets.Dataset.from_dict(
episode_dict, features=self.hf_features, split="train"
)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
write_parquet(ep_dataset, ep_data_path)
@ -886,10 +958,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
def consolidate(
self, run_compute_stats: bool = True, keep_image_files: bool = False
) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
check_timestamps_sync(
self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s
)
if len(self.meta.video_keys) > 0:
self.encode_videos()
@ -994,7 +1072,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
self.tolerances_s = (
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
@ -1071,7 +1151,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
features.update(
{
k: v
for k, v in dataset.hf_features.items()
if k not in self.disabled_features
}
)
return features
@property
@ -1132,7 +1218,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
raise AssertionError(
"We expect the loop to break out as long as the index is within bounds."
)
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:

View File

@ -131,7 +131,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
else:
self._delta_timestamps = None
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
def _make_data_spec(
self, data_spec: dict[str, Any], buffer_capacity: int
) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
@ -154,14 +156,32 @@ class OnlineBuffer(torch.utils.data.Dataset):
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
OnlineBuffer.OCCUPANCY_MASK_KEY: {
"dtype": np.dtype("?"),
"shape": (buffer_capacity,),
},
OnlineBuffer.INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.FRAME_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.EPISODE_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.TIMESTAMP_KEY: {
"dtype": np.dtype("float64"),
"shape": (buffer_capacity,),
},
}
for k, v in data_spec.items():
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
complete_data_spec[k] = {
"dtype": v["dtype"],
"shape": (buffer_capacity, *v["shape"]),
}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):
@ -188,7 +208,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
next_index - 1
]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@ -223,7 +245,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
np.unique(
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
]
)
)
@property
@ -261,7 +287,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
episode_data_indices
]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@ -278,7 +306,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
(query_ts[is_pad] < episode_timestamps[0])
| (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
@ -293,7 +322,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
return torch.from_numpy(
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
)
def compute_sampler_weights(
@ -324,13 +355,19 @@ def compute_sampler_weights(
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if len(offline_dataset) == 0 and (
online_dataset is None or len(online_dataset) == 0
):
raise ValueError(
"At least one of `offline_dataset` or `online_dataset` should be contain data."
)
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
offline_sampling_ratio = (
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
)
weights = []

View File

@ -37,10 +37,16 @@ def check_chunks_compatible(chunks: tuple, shape: tuple):
assert c > 0
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
def rechunk_recompress_array(
group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"
):
old_arr = group[name]
if chunks is None:
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
chunks = (
(chunk_length,) + old_arr.chunks[1:]
if chunk_length is not None
else old_arr.chunks
)
check_chunks_compatible(chunks, old_arr.shape)
if compressor is None:
@ -82,13 +88,18 @@ def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=No
for i in range(len(shape) - 1):
this_chunk_bytes = itemsize * np.prod(rshape[:i])
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
if (
this_chunk_bytes <= target_chunk_bytes
and next_chunk_bytes > target_chunk_bytes
):
split_idx = i
rchunks = rshape[:split_idx]
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
this_max_chunk_length = rshape[split_idx]
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
next_chunk_length = min(
this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)
)
rchunks.append(next_chunk_length)
len_diff = len(shape) - len(rchunks)
rchunks.extend([1] * len_diff)
@ -124,7 +135,13 @@ class ReplayBuffer:
root.require_group("data", overwrite=False)
meta = root.require_group("meta", overwrite=False)
if "episode_ends" not in meta:
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
meta.zeros(
"episode_ends",
shape=(0,),
dtype=np.int64,
compressor=None,
overwrite=False,
)
return cls(root=root)
@classmethod
@ -193,7 +210,11 @@ class ReplayBuffer:
root = zarr.group(store=store)
# copy without recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
source=src_store,
dest=store,
source_path="/meta",
dest_path="/meta",
if_exists=if_exists,
)
data_group = root.create_group("data", overwrite=True)
if keys is None:
@ -201,7 +222,9 @@ class ReplayBuffer:
for key in keys:
value = src_root["data"][key]
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
cpr = cls._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
this_path = "/data/" + key
@ -286,13 +309,17 @@ class ReplayBuffer:
meta_group = root.create_group("meta", overwrite=True)
# save meta, no chunking
for key, value in self.root["meta"].items():
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
_ = meta_group.array(
name=key, data=value, shape=value.shape, chunks=value.shape
)
# save data, chunk
data_group = root.create_group("data", overwrite=True)
for key, value in self.root["data"].items():
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
cpr = self._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
if isinstance(value, zarr.Array):
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
@ -339,13 +366,19 @@ class ReplayBuffer:
@staticmethod
def resolve_compressor(compressor="default"):
if compressor == "default":
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
compressor = numcodecs.Blosc(
cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE
)
elif compressor == "disk":
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
compressor = numcodecs.Blosc(
"zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE
)
return compressor
@classmethod
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
def _resolve_array_compressor(
cls, compressors: dict | str | numcodecs.abc.Codec, key, array
):
# allows compressor to be explicitly set to None
cpr = "nil"
if isinstance(compressors, dict):
@ -404,7 +437,11 @@ class ReplayBuffer:
if self.backend == "zarr":
for key, value in np_data.items():
_ = meta_group.array(
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
name=key,
data=value,
shape=value.shape,
chunks=value.shape,
overwrite=True,
)
else:
meta_group.update(np_data)
@ -514,10 +551,18 @@ class ReplayBuffer:
# create array
if key not in self.data:
if is_zarr:
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
cks = self._resolve_array_chunks(
chunks=chunks, key=key, array=value
)
cpr = self._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
arr = self.data.zeros(
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
name=key,
shape=new_shape,
chunks=cks,
dtype=value.dtype,
compressor=cpr,
)
else:
# copy data to prevent modify
@ -544,7 +589,9 @@ class ReplayBuffer:
# rechunk
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
rechunk_recompress_array(
self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)
)
def drop_episode(self):
is_zarr = self.backend == "zarr"

View File

@ -38,7 +38,9 @@ import argparse
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
from lerobot.common.datasets.push_dataset_to_hub._download_raw import (
AVAILABLE_RAW_REPO_IDS,
)
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
@ -73,7 +75,9 @@ def encode_datasets(
check_repo_id(raw_repo_id)
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
dataset_raw_dir = raw_dir / raw_repo_id
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
dataset_dir = (
local_dir / dataset_repo_id_push if local_dir is not None else None
)
encoding = {
"vcodec": vcodec,
"pix_fmt": pix_fmt,

View File

@ -133,7 +133,9 @@ class Jpeg2k(Codec):
)
def decode(self, buf, out=None):
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
return imagecodecs.jpeg2k_decode(
buf, verbose=self.verbose, numthreads=self.numthreads, out=out
)
class JpegXl(Codec):

View File

@ -44,7 +44,9 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def get_cameras(hdf5_data):
# ignore depth channel, not currently handled
# TODO(rcadene): add depth
rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
rgb_cameras = [
key for key in hdf5_data["/observations/images"].keys() if "depth" not in key
] # noqa: SIM118
return rgb_cameras
@ -73,7 +75,9 @@ def check_format(raw_dir) -> bool:
else:
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
assert (
c < h and c < w
), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(
@ -134,14 +138,17 @@ def load_from_raw(
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
encode_video_frames(
tmp_imgs_dir, video_path, fps, **(encoding or {})
)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@ -181,15 +188,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.velocity"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.effort"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@ -26,7 +26,9 @@ import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
@ -42,11 +44,19 @@ def check_format(raw_dir) -> bool:
return True
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
def load_from_raw(
raw_dir: Path,
videos_dir: Path,
fps: int,
video: bool,
episodes: list[int] | None = None,
):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
raise ValueError(
f"Missing reference files for camera, starting with in '{raw_dir}'"
)
# select first camera in alphanumeric order
reference_key = sorted(reference_files)[0].stem
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
@ -107,7 +117,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
# each episode starts with timestamp 0 to match the ones from the video
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(
lambda x: x - x.iloc[0]
)
del df["timestamp_utc"]
@ -120,7 +132,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids:
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
raise ValueError(
f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
)
# Create symlink to raw videos directory (that needs to be absolute not relative)
videos_dir.parent.mkdir(parents=True, exist_ok=True)
@ -152,7 +166,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
data_dict[key] = torch.from_numpy(df[key].values)
# is vector
elif df[key].iloc[0].shape[0] > 1:
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
data_dict[key] = torch.stack(
[torch.from_numpy(x.copy()) for x in df[key].values]
)
else:
raise ValueError(key)
@ -170,15 +186,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.velocity"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.effort"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@ -143,7 +143,11 @@ def load_from_raw(
else:
state_keys.append(key)
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
lang_key = (
"language_instruction"
if "language_instruction" in dataset.element_spec
else None
)
print(" - image_keys: ", image_keys)
print(" - lang_key: ", lang_key)
@ -202,7 +206,9 @@ def load_from_raw(
# If lang_key is present, convert the entire tensor at once
if lang_key is not None:
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
ep_dict["language_instruction"] = [
x.numpy().decode("utf-8") for x in episode[lang_key]
]
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
@ -234,7 +240,8 @@ def load_from_raw(
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@ -259,7 +266,9 @@ def to_hf_dataset(data_dict, video) -> Dataset:
for key in data_dict:
# check if vector state obs
if key.startswith("observation.") and "observation.images." not in key:
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
features[key] = Sequence(
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
)
# check if image obs
elif "observation.images." in key:
if video:

View File

@ -56,7 +56,9 @@ def check_format(raw_dir):
required_datasets.remove("meta/episode_ends")
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
assert all(
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
)
def load_from_raw(
@ -76,7 +78,9 @@ def load_from_raw(
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
# as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
success_threshold = 0.95 # 95% coverage,
@ -150,7 +154,9 @@ def load_from_raw(
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
block_body, block_shapes = PushTEnv.add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
@ -159,7 +165,9 @@ def load_from_raw(
reward[i] = np.clip(coverage / success_threshold, 0, 1)
success[i] = coverage > success_threshold
if keypoints_instead_of_image:
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
keypoints[i] = torch.from_numpy(
PushTEnv.get_keypoints(block_shapes).flatten()
)
# last step of demonstration is considered done
done[-1] = True
@ -184,7 +192,8 @@ def load_from_raw(
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@ -193,7 +202,9 @@ def load_from_raw(
if keypoints_instead_of_image:
ep_dict["observation.environment_state"] = keypoints
ep_dict["action"] = actions[from_idx:to_idx]
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = image[1:],
@ -220,7 +231,8 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
if keypoints_instead_of_image:
features["observation.environment_state"] = Sequence(
@ -261,7 +273,9 @@ def from_raw_to_lerobot_format(
if fps is None:
fps = 10
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
data_dict = load_from_raw(
raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding
)
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {

View File

@ -26,7 +26,9 @@ from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import (
register_codecs,
)
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
@ -61,7 +63,9 @@ def check_format(raw_dir) -> bool:
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
required_datasets.remove("meta/episode_ends")
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
assert all(
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
)
def load_from_raw(
@ -79,7 +83,9 @@ def load_from_raw(
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
eff_rot_axis_angle = torch.from_numpy(
zarr_data["data/robot0_eef_rot_axis_angle"][:]
)
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
@ -129,24 +135,31 @@ def load_from_raw(
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
encode_video_frames(
tmp_imgs_dir, video_path, fps, **(encoding or {})
)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
ep_dict["episode_data_index_to"] = torch.tensor(
[from_idx + num_frames] * num_frames
)
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
@ -172,7 +185,8 @@ def to_hf_dataset(data_dict, video):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
@ -192,7 +206,8 @@ def to_hf_dataset(data_dict, video):
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
)
features["gripper_width"] = Sequence(
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["gripper_width"].shape[1],
feature=Value(dtype="float32", id=None),
)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))

View File

@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts):
return data_dict
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
def save_images_concurrently(
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
[
executor.submit(save_image, imgs_array[i], i, out_dir)
for i in range(num_images)
]
def get_default_encoding() -> dict:
@ -64,7 +69,8 @@ def get_default_encoding() -> dict:
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
if v.default is not inspect.Parameter.empty
and k in ["vcodec", "pix_fmt", "g", "crf"]
}
@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None:
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

View File

@ -40,7 +40,10 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir):
keys = {"actions", "rewards", "dones"}
nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
nested_keys = {
"observations": {"rgb", "state"},
"next_observations": {"rgb", "state"},
}
xarm_files = list(raw_dir.glob("*.pkl"))
assert len(xarm_files) > 0
@ -53,11 +56,17 @@ def check_format(raw_dir):
# Check for consistent lengths in nested keys
expected_len = len(dataset_dict["actions"])
assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)
assert all(
len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict
)
for key, subkeys in nested_keys.items():
nested_dict = dataset_dict.get(key, {})
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
assert all(
len(nested_dict[subkey]) == expected_len
for subkey in subkeys
if subkey in nested_dict
)
def load_from_raw(
@ -122,13 +131,18 @@ def load_from_raw(
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = next_image
@ -153,7 +167,8 @@ def to_hf_dataset(data_dict, video):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@ -43,7 +43,10 @@ class EpisodeAwareSampler:
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
range(
start_index.item() + drop_n_first_frames,
end_index.item() - drop_n_last_frames,
)
)
self.indices = indices

View File

@ -58,7 +58,9 @@ class RandomSubsetApply(Transform):
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
raise ValueError(
f"n_subset should be in the interval [1, {len(transforms)}]"
)
self.transforms = transforms
total = sum(p)
@ -119,16 +121,22 @@ class SharpnessJitter(Transform):
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
raise ValueError(
"If sharpness is a single number, it must be non negative."
)
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
raise TypeError(
f"{sharpness=} should be a single number or a sequence with length 2."
)
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
raise ValueError(
f"sharpnesss values should be between (0., inf), but got {sharpness}."
)
return float(sharpness[0]), float(sharpness[1])

View File

@ -44,9 +44,15 @@ EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DEFAULT_VIDEO_PATH = (
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
)
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
)
DEFAULT_IMAGE_PATH = (
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
)
DATASET_CARD_TEMPLATE = """
---
@ -112,7 +118,9 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
serialized_dict = {
key: value.tolist() for key, value in flatten_dict(stats).items()
}
return unflatten_dict(serialized_dict)
@ -170,14 +178,19 @@ def load_stats(local_dir: Path) -> dict:
def load_tasks(local_dir: Path) -> dict:
tasks = load_jsonlines(local_dir / TASKS_PATH)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
return {
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH)
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
def load_image_as_numpy(
fpath: str | Path, dtype="float32", channel_first: bool = True
) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
@ -235,7 +248,10 @@ class BackwardCompatibilityError(Exception):
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
repo_id: str,
version_to_check: str,
current_version: str,
enforce_breaking_major: bool = True,
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
@ -361,7 +377,9 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
episode_lengths = {
ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)
}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@ -382,7 +400,9 @@ def calculate_total_episode(
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
@ -424,7 +444,9 @@ def check_timestamps_sync(
# Track original indices before masking
original_indices = torch.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
outside_tolerance_filtered_indices = torch.nonzero(
~filtered_within_tolerance
) # .squeeze()
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])
@ -449,7 +471,10 @@ def check_timestamps_sync(
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
delta_timestamps: dict[str, list[float]],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
@ -457,10 +482,14 @@ def check_delta_timestamps(
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
within_tolerance = [
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
ts
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
if not is_within
]
if len(outside_tolerance) > 0:
@ -478,7 +507,9 @@ def check_delta_timestamps(
return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
def get_delta_indices(
delta_timestamps: dict[str, list[float]], fps: int
) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
@ -543,7 +574,9 @@ def create_lerobot_dataset_card(
],
)
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
card_template = (
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
).read_text()
return DatasetCard.from_template(
card_data=card_data,

View File

@ -117,7 +117,10 @@ DATASETS = {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"aloha_static_candy": {
"single_task": "Pick up the candy and unwrap it.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
@ -166,13 +169,22 @@ DATASETS = {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_static_ziploc_slide": {
"single_task": "Slide open the ziploc bag.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_human": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
@ -193,10 +205,19 @@ DATASETS = {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"pusht_image": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {
"single_task": "Put the object into the bin.",
**UNITREEH_INFO,
},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
@ -206,13 +227,31 @@ DATASETS = {
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",

View File

@ -218,7 +218,9 @@ def get_features_from_hf_dataset(
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
robot_config["names"][key]
if robot_config
else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
@ -242,11 +244,15 @@ def get_features_from_hf_dataset(
return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
def add_task_index_by_episodes(
dataset: Dataset, tasks_by_episodes: dict
) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
episodes_to_task_index = {
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
@ -263,10 +269,19 @@ def add_task_index_from_tasks_col(
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
df[tasks_col] = (
df[tasks_col]
.str.removeprefix(prefix_to_clean)
.str.removesuffix(suffix_to_clean)
)
# Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks_by_episode = (
df.groupby("episode_index")[tasks_col]
.unique()
.apply(lambda x: x.tolist())
.to_dict()
)
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
@ -291,7 +306,9 @@ def split_parquet_by_episodes(
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk
)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@ -323,7 +340,9 @@ def move_videos(
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
video_files = [
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
@ -354,7 +373,9 @@ def move_videos(
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
video_file = V1_VIDEO_FILE.format(
video_key=vid_key, episode_index=ep_idx
)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
@ -371,7 +392,9 @@ def move_videos(
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
def fix_lfs_video_files_tracking(
work_dir: Path, lfs_untracked_videos: list[str]
) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
@ -379,7 +402,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
subprocess.run(
["git", "rm", "--cached", *files],
cwd=work_dir,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
@ -390,10 +418,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
def fix_gitattributes(
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
@ -402,7 +434,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
[
"git",
"clone",
"--branch",
branch,
"--single-branch",
"--depth",
"1",
repo_url,
str(work_dir),
],
check=True,
env=env,
)
@ -410,13 +452,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
["git", "lfs", "ls-files", "-n"],
cwd=work_dir,
capture_output=True,
text=True,
check=True,
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
def get_videos_info(
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
@ -424,7 +472,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
repo_id=repo_id,
repo_type="dataset",
local_dir=local_dir,
revision=branch,
allow_patterns=video_files,
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
@ -451,7 +503,11 @@ def convert_dataset(
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
repo_id=repo_id,
repo_type="dataset",
revision=v1,
local_dir=v1x_dir,
ignore_patterns="videos*/",
)
branch = "main"
if test_branch:
@ -483,19 +539,31 @@ def convert_dataset(
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
dataset, tasks_col
)
else:
raise ValueError
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
assert set(tasks) == {
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
}
tasks = [
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
@ -509,14 +577,25 @@ def convert_dataset(
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
repo_id=GITATTRIBUTES_REF,
repo_type="dataset",
local_dir=local_dir,
filename=".gitattributes",
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
repo_id,
video_keys,
total_episodes,
total_chunks,
Path(tmp_video_dir),
clean_gitattr,
branch,
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
videos_info = get_videos_info(
repo_id, v1x_dir, video_keys=video_keys, branch=branch
)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
@ -524,15 +603,22 @@ def convert_dataset(
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
assert math.isclose(
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
)
if "encoding" in metadata_v1:
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
assert (
videos_info[key]["video.pix_fmt"]
== metadata_v1["encoding"]["pix_fmt"]
)
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
episode_lengths = split_parquet_by_episodes(
dataset, total_episodes, total_chunks, v20_dir
)
if robot_config is not None:
robot_type = robot_config.type
@ -543,7 +629,11 @@ def convert_dataset(
# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
{
"episode_index": ep_idx,
"tasks": tasks_by_episodes[ep_idx],
"length": episode_lengths[ep_idx],
}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
@ -566,16 +656,27 @@ def convert_dataset(
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
card = create_lerobot_dataset_card(
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id,
path_in_repo="meta_data",
repo_type="dataset",
revision=branch,
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
)
hub_api.upload_folder(
repo_id=repo_id,

View File

@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@ -241,7 +243,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.bit_rate": int(audio_stream_info["bit_rate"])
if audio_stream_info.get("bit_rate")
else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
@ -263,7 +267,9 @@ def get_video_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")

View File

@ -70,7 +70,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
def make_maniskill_env(
cfg: DictConfig, n_envs: int | None = None
) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
@ -87,7 +89,9 @@ def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
return env
@ -114,7 +118,11 @@ class PixelWrapper(gym.Wrapper):
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
return {
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
self.env.device
)
}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
@ -148,7 +156,9 @@ class ConvertToLeRobotEnv(gym.Wrapper):
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
ret = dict()
ret["state"] = observation
ret["pixels"] = images

View File

@ -84,7 +84,9 @@ class Logger:
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
def __init__(
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
@ -104,7 +106,9 @@ class Logger:
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
logging.info(
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
)
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
@ -130,7 +134,9 @@ class Logger:
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb
@classmethod
@ -151,7 +157,9 @@ class Logger:
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
def save_model(
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
@ -221,22 +229,30 @@ class Logger:
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
checkpoint_dir / self.pretrained_model_dir_name,
policy,
wandb_artifact_name=wandb_artifact_name,
)
self.save_training_state(
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
def load_last_training_state(
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
training_state = torch.load(
self.last_checkpoint_dir / self.training_state_file_name
)
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
assert set(training_state["optimizer"].keys()) == set(
optimizer.keys()
), "Optimizer dictionaries do not have the same keys during resume!"
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:
@ -248,10 +264,18 @@ class Logger:
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
set_global_random_state(
{k: training_state[k] for k in get_global_random_state()}
)
return training_state["step"]
def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
def log_dict(
self,
d,
step: int | None = None,
mode="train",
custom_step_key: str | None = None,
):
"""Log a dictionary of metrics to WandB."""
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
@ -280,12 +304,20 @@ class Logger:
continue
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
self._wandb.log(
{
f"{mode}/{k}": v,
f"{mode}/{custom_step_key}": value_custom_step,
}
)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)

View File

@ -74,7 +74,9 @@ class ACTPolicy(PreTrainedPolicy):
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.temporal_ensembler = ACTTemporalEnsembler(
config.temporal_ensemble_coeff, config.chunk_size
)
self.reset()
@ -156,7 +158,8 @@ class ACTPolicy(PreTrainedPolicy):
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}
@ -166,7 +169,12 @@ class ACTPolicy(PreTrainedPolicy):
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
(
-0.5
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
)
.sum(-1)
.mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight
@ -220,7 +228,9 @@ class ACTTemporalEnsembler:
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights = torch.exp(
-temporal_ensemble_coeff * torch.arange(chunk_size)
)
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
@ -236,7 +246,9 @@ class ACTTemporalEnsembler:
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
device=actions.device
)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
@ -244,19 +256,34 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
(self.chunk_size, 1),
dtype=torch.long,
device=self.ensembled_actions.device,
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
self.ensembled_actions *= self.ensemble_weights_cumsum[
self.ensembled_actions_count - 1
]
self.ensembled_actions += (
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
)
self.ensembled_actions /= self.ensemble_weights_cumsum[
self.ensembled_actions_count
]
self.ensembled_actions_count = torch.clamp(
self.ensembled_actions_count + 1, max=self.chunk_size
)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions = torch.cat(
[self.ensembled_actions, actions[:, -1:]], dim=1
)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
[
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
@ -322,7 +349,9 @@ class ACT(nn.Module):
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
self.vae_encoder_latent_output_proj = nn.Linear(
config.dim_model, config.latent_dim * 2
)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
@ -330,20 +359,28 @@ class ACT(nn.Module):
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
create_sinusoidal_pos_embedding(
num_input_token_encoder, config.dim_model
).unsqueeze(0),
)
# Backbone for image feature extraction.
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer4": "feature_map"}
)
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
@ -389,7 +426,9 @@ class ACT(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
def forward(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
@ -428,7 +467,9 @@ class ACT(nn.Module):
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
@ -469,20 +510,24 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
latent_sample = torch.zeros(
[batch_size, self.config.latent_dim], dtype=torch.float32
).to(batch["observation.state"].device)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
encoder_in_pos_embed = list(
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
)
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
self.encoder_env_state_input_proj(
batch["observation.environment_state"]
)
)
# Camera observation features and positional embeddings.
@ -491,19 +536,29 @@ class ACT(nn.Module):
all_cam_pos_embeds = []
for cam_index in range(batch["observation.images"].shape[-4]):
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
cam_features = self.backbone(batch["observation.images"][:, cam_index])[
"feature_map"
]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
# buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
dtype=cam_features.dtype
)
cam_features = self.encoder_img_feat_input_proj(
cam_features
) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
# and move to (sequence, batch, dim).
all_cam_features = torch.cat(all_cam_features, axis=-1)
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
encoder_in_tokens.extend(
einops.rearrange(all_cam_features, "b c h w -> (h w) b c")
)
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
encoder_in_pos_embed.extend(
einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")
)
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
@ -538,12 +593,21 @@ class ACTEncoder(nn.Module):
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__()
self.is_vae_encoder = is_vae_encoder
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
num_layers = (
config.n_vae_encoder_layers
if self.is_vae_encoder
else config.n_encoder_layers
)
self.layers = nn.ModuleList(
[ACTEncoderLayer(config) for _ in range(num_layers)]
)
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@ -554,7 +618,9 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@ -569,7 +635,9 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
def forward(
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
@ -594,7 +662,9 @@ class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.layers = nn.ModuleList(
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
)
self.norm = nn.LayerNorm(config.dim_model)
def forward(
@ -606,7 +676,10 @@ class ACTDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
)
if self.norm is not None:
x = self.norm(x)
@ -616,8 +689,12 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.multihead_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@ -658,7 +735,9 @@ class ACTDecoderLayer(nn.Module):
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x)[
0
] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@ -695,9 +774,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
return [
position / np.power(10000, 2 * (hid_j // 2) / dimension)
for hid_j in range(dimension)
]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
@ -742,7 +826,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
2
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
/ self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
@ -750,9 +836,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
pos_embed_x = torch.stack(
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed_y = torch.stack(
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
0, 3, 1, 2
) # (1, C, H, W)
return pos_embed

View File

@ -132,7 +132,11 @@ class DiffusionPolicy(PreTrainedPolicy):
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
@ -189,7 +193,9 @@ class DiffusionModel(nn.Module):
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.unet = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
@ -209,7 +215,10 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
@ -232,7 +241,9 @@ class DiffusionModel(nn.Module):
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
sample = self.noise_scheduler.step(
model_output, t, sample, generator=generator
).prev_sample
return sample
@ -244,27 +255,39 @@ class DiffusionModel(nn.Module):
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
images_per_camera = einops.rearrange(
batch["observation.images"], "b s n ... -> n (b s) ..."
)
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
for encoder, images in zip(
self.rgb_encoder, images_per_camera, strict=True
)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features_list,
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
einops.rearrange(
batch["observation.images"], "b s n ... -> (b s n) ..."
)
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features,
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
global_cond_feats.append(img_features)
@ -350,7 +373,9 @@ class DiffusionModel(nn.Module):
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
raise ValueError(
f"Unsupported prediction type {self.config.prediction_type}"
)
loss = F.mse_loss(pred, target, reduction="none")
@ -410,7 +435,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@ -452,7 +479,9 @@ class DiffusionRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@ -473,7 +502,9 @@ class DiffusionRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@ -515,7 +546,9 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@ -528,7 +561,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@ -543,7 +580,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module
@ -571,7 +610,9 @@ class DiffusionConv1dBlock(nn.Module):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.Conv1d(
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
@ -594,9 +635,13 @@ class DiffusionConditionalUnet1d(nn.Module):
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Linear(
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
),
nn.Mish(),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
nn.Linear(
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
),
)
# The FiLM conditioning dimension.
@ -621,10 +666,16 @@ class DiffusionConditionalUnet1d(nn.Module):
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(
dim_in, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
@ -633,10 +684,14 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
]
)
@ -649,10 +704,16 @@ class DiffusionConditionalUnet1d(nn.Module):
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(
dim_in * 2, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
@ -726,17 +787,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
self.conv1 = DiffusionConv1dBlock(
in_channels, out_channels, kernel_size, n_groups=n_groups
)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
self.conv2 = DiffusionConv1dBlock(
out_channels, out_channels, kernel_size, n_groups=n_groups
)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
nn.Conv1d(in_channels, out_channels, 1)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:

View File

@ -7,7 +7,9 @@ from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@ -15,7 +17,10 @@ class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
@ -43,12 +48,14 @@ class Classifier(
name = "classifier"
def __init__(self, config: ClassifierConfig):
from transformers import AutoImageProcessor, AutoModel
from transformers import AutoModel
super().__init__()
self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(
self.config.model_name, trust_remote_code=True
)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
@ -74,7 +81,9 @@ class Classifier(
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
self.feature_dim = self.encoder.config.hidden_sizes[
-1
] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
@ -94,14 +103,19 @@ class Classifier(
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
raise ValueError(
"Unsupported transformer architecture since hidden_size is not found"
)
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
self.classifier_head = self.classifier_head.to(self.config.device)
@ -127,7 +141,10 @@ class Classifier(
return features
else: # Transformer models
outputs = self.encoder(processed)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
if (
hasattr(outputs, "pooler_output")
and outputs.pooler_output is not None
):
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
@ -143,7 +160,9 @@ class Classifier(
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
return ClassifierOutput(
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
)
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:

View File

@ -59,7 +59,9 @@ class SACPolicy(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, input_normalization_params
config.input_shapes,
config.input_normalization_modes,
input_normalization_params,
)
else:
self.normalize_inputs = nn.Identity()
@ -90,7 +92,8 @@ class SACPolicy(
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@ -104,7 +107,8 @@ class SACPolicy(
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@ -120,13 +124,17 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
network=MLP(
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
config.target_entropy = (
-np.prod(config.output_shapes["action"][0]) / 2
) # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@ -153,7 +161,11 @@ class SACPolicy(
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
@ -173,21 +185,37 @@ class SACPolicy(
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
@ -204,7 +232,12 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
q_preds = self.critic_forward(
observations,
actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@ -219,20 +252,31 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
def compute_loss_temperature(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
temperature_loss = (
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
).mean()
return temperature_loss
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
q_preds = self.critic_forward(
observations,
actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
@ -259,7 +303,11 @@ class MLP(nn.Module):
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
# Rest of the layers
for i in range(1, len(hidden_dims)):
@ -270,7 +318,9 @@ class MLP(nn.Module):
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@ -381,7 +431,11 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
inputs = torch.cat([obs_enc, actions], dim=-1)
q_values = self.ensemble(inputs) # [num_critics, B, 1]
@ -445,7 +499,11 @@ class Policy(nn.Module):
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
# Get network outputs
outputs = self.network(obs_enc)
@ -454,11 +512,15 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
@ -471,7 +533,9 @@ class Policy(nn.Module):
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
log_probs -= torch.log(
(1 - actions.pow(2)) + 1e-6
) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
@ -518,12 +582,15 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.all_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
in_features=config.input_shapes["observation.state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
@ -544,7 +611,9 @@ class SACObservationEncoder(nn.Module):
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.aggregation_layer = nn.Linear(
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
@ -557,13 +626,19 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = torch.cat(
[obs_dict[key] for key in self.all_image_keys], dim=0
)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
embeddings_chunks = torch.chunk(
images_batched, dim=0, chunks=len(self.all_image_keys)
)
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
@ -631,7 +706,9 @@ class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_layers, self.image_enc_out_shape = (
self._load_pretrained_vision_encoder(config)
)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
@ -642,15 +719,21 @@ class PretrainedImageEncoder(nn.Module):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
self.image_enc_layers = AutoModel.from_pretrained(
config.vision_encoder_name, trust_remote_code=True
)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
-1
] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
raise ValueError(
"Unsupported vision encoder architecture, make sure you are using a CNN"
)
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
@ -673,7 +756,7 @@ def orthogonal_init():
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
super().__init__()
def forward(self, x):
return x
@ -701,7 +784,9 @@ class Ensemble(nn.Module):
return self.module(*args, **kwargs)
def forward(self, *args, **kwargs):
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
return torch.vmap(self._call, (0, None), randomness="different")(
self.params, *args, **kwargs
)
def __repr__(self):
return f"Vectorized {len(self)}x " + self._repr
@ -710,7 +795,9 @@ class Ensemble(nn.Module):
# TODO (azouitine): I think in our case this function is not usefull we should remove it
# after some investigation
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
@ -736,7 +823,9 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
converted_params[outer_key][key] = converted_params[outer_key][
key
].view(3, 1, 1)
return converted_params

View File

@ -183,7 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
raise ValueError(
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")

View File

@ -100,7 +100,9 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
"action": deque(
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
),
}
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
@ -189,7 +191,11 @@ class TDMPCPolicy(PreTrainedPolicy):
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
z = einops.repeat(
z,
"b d -> n b d",
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
@ -211,35 +217,47 @@ class TDMPCPolicy(PreTrainedPolicy):
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
gaussian_actions = torch.clamp(
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_idxs = torch.topk(
value, self.config.n_elites, dim=0
).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
elite_actions = actions.take_along_dim(
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score = torch.exp(
self.config.elite_weighting_temperature * (elite_value - max_value)
)
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_mean = torch.sum(
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
self.config.gaussian_mean_momentum * mean
+ (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
@ -248,7 +266,9 @@ class TDMPCPolicy(PreTrainedPolicy):
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
actions = elite_actions[
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
]
return actions
@ -271,7 +291,8 @@ class TDMPCPolicy(PreTrainedPolicy):
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
self.config.uncertainty_regularizer_coeff
* self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
@ -291,15 +312,22 @@ class TDMPCPolicy(PreTrainedPolicy):
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
* torch.min(
terminal_values[
torch.randint(0, self.config.q_ensemble_size, size=(2,))
],
dim=0,
)[0]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
G -= (
running_discount
* self.config.uncertainty_regularizer_coeff
* terminal_values.std(0)
)
return G
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@ -329,7 +357,10 @@ class TDMPCPolicy(PreTrainedPolicy):
# Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
partial(
random_shifts_aug,
max_random_shift_ratio=self.config.max_random_shift_ratio,
),
observations["observation.image"],
)
@ -347,14 +378,20 @@ class TDMPCPolicy(PreTrainedPolicy):
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds = torch.empty(
horizon + 1, batch_size, self.config.latent_dim, device=device
)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
z_preds[t], action[t]
)
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
q_preds_ensemble = self.model.Qs(
z_preds[:-1], action
) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
@ -368,10 +405,14 @@ class TDMPCPolicy(PreTrainedPolicy):
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
q_targets = reward + self.config.discount * self.model.V(
self.model.encode(next_observations)
)
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
v_targets = self.model_target.Qs(
z_preds[:-1].detach(), action, return_min=True
)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@ -414,7 +455,9 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
einops.repeat(
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
@ -452,12 +495,14 @@ class TDMPCPolicy(PreTrainedPolicy):
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
advantage = self.model_target.Qs(
z_preds[:-1], action, return_min=True
) - self.model.V(z_preds[:-1])
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
exp_advantage = torch.clamp(
torch.exp(advantage * self.config.advantage_scaling), max=100.0
)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@ -511,7 +556,9 @@ class TDMPCPolicy(PreTrainedPolicy):
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
update_ema_parameters(
self.model_target, self.model, self.config.target_model_momentum
)
class TDMPCTOLD(nn.Module):
@ -598,7 +645,9 @@ class TDMPCTOLD(nn.Module):
"Sanity check. The last linear layer needs 0 initialization on weights."
)
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
nn.init.zeros_(
m[-1].bias
) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
@ -702,11 +751,26 @@ class TDMPCObservationEncoder(nn.Module):
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
)
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
@ -796,12 +860,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
ema_module.named_parameters(recurse=False),
module.named_parameters(recurse=False),
strict=True,
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
if (
isinstance(module, nn.modules.batchnorm._BatchNorm)
or not p.requires_grad
):
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
@ -809,7 +878,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:

View File

@ -145,8 +145,14 @@ class VQBeTPolicy(PreTrainedPolicy):
)
if len(self._queues["action"]) == 0:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.vqbet(batch, rollout=True)[
:, : self.config.action_chunk_size
]
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
@ -225,7 +231,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@ -339,7 +347,12 @@ class VQBeTModel(nn.Module):
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
torch.row_stack(
[
torch.arange(i, i + self.config.action_chunk_size)
for i in range(num_tokens)
]
),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
@ -354,7 +367,11 @@ class VQBeTModel(nn.Module):
)
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
img_features,
"(b s n) ... -> b s n ...",
b=batch_size,
s=n_obs_steps,
n=self.num_images,
)
# Arrange prior and current observation step tokens as shown in the class docstring.
@ -366,13 +383,19 @@ class VQBeTModel(nn.Module):
input_tokens.append(
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
input_tokens.append(
einops.repeat(
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
)
)
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token - 1
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
future_action_tokens = self.action_token.repeat(
batch_size, len_additional_action_token, 1
)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
@ -391,7 +414,11 @@ class VQBeTModel(nn.Module):
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0:
features = torch.cat(
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
[
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:],
],
dim=1,
)
else:
features = features[:, historical_act_pred_index]
@ -399,13 +426,15 @@ class VQBeTModel(nn.Module):
action_head_output = self.action_head(features)
# if rollout, VQ-BeT don't calculate loss
if rollout:
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
batch_size, self.config.action_chunk_size, -1
)
return action_head_output["predicted_action"][
:, n_obs_steps - 1, :
].reshape(batch_size, self.config.action_chunk_size, -1)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
output = batch["action"][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
loss = self.action_head.loss_fn(
action_head_output, output, reduction="mean"
)
return action_head_output, loss
@ -440,7 +469,9 @@ class VQBeTHead(nn.Module):
else:
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
hidden_channels=[
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
],
)
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
@ -467,7 +498,10 @@ class VQBeTHead(nn.Module):
loss, metric = self.vqvae_model.vqvae_forward(actions)
n_different_codes = sum(
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
[
len(torch.unique(metric[2][:, i]))
for i in range(self.vqvae_model.vqvae_num_layers)
]
)
n_different_combinations = len(torch.unique(metric[2], dim=0))
recon_l1_error = metric[0].detach().cpu().item()
@ -514,7 +548,13 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat(
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
(
x,
F.one_hot(
sampled_primary_centers,
num_classes=self.config.vqvae_n_embed,
),
),
axis=1,
)
)
@ -522,19 +562,29 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
)
sampled_secondary_centers = einops.rearrange(
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
torch.multinomial(
cbet_secondary_probs.view(-1, choices), num_samples=1
),
"(NT) 1 -> NT",
NT=NT,
)
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
sampled_centers = torch.stack(
(sampled_primary_centers, sampled_secondary_centers), axis=1
)
cbet_logits = torch.stack(
[cbet_primary_logits, cbet_secondary_logits], dim=1
)
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
cbet_logits,
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
)
cbet_probs = torch.softmax(
cbet_logits / self.config.bet_softmax_temperature, dim=-1
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
@ -554,9 +604,17 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
with torch.no_grad():
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
return_decoder_input = (
self.vqvae_model.get_embeddings_from_code(sampled_centers)
.clone()
.detach()
)
# pass the centroids through decoder to get actions.
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
@ -605,7 +663,9 @@ class VQBeTHead(nn.Module):
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
# Now we can compute the loss.
@ -628,8 +688,12 @@ class VQBeTHead(nn.Module):
+ cbet_loss2 * self.config.secondary_code_loss_weight
)
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
@ -643,7 +707,9 @@ class VQBeTHead(nn.Module):
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
.cpu()
.item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
@ -668,7 +734,9 @@ class VQBeTRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@ -689,7 +757,9 @@ class VQBeTRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@ -730,7 +800,9 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@ -743,7 +815,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@ -758,7 +834,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module

View File

@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module):
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = (
y.transpose(1, 2).contiguous().view(B, T, C)
) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
@ -189,12 +197,16 @@ class GPT(nn.Module):
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
}
)
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
self.lm_head = nn.Linear(
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
@ -208,11 +220,17 @@ class GPT(nn.Module):
)
# positional encodings that are added to the input embeddings
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)
# forward the GPT model itself
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, gpt_hidden_dim)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
@ -237,7 +255,9 @@ class GPT(nn.Module):
# but want to use a smaller block size for some smaller, simpler model
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:gpt_block_size]
)
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
@ -270,7 +290,9 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
assert (
len(inter_params) == 0
), "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert len(param_dict.keys() - union_params) == 0, (
@ -368,8 +390,12 @@ class ResidualVQ(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.num_quantizers = num_quantizers
@ -377,7 +403,10 @@ class ResidualVQ(nn.Module):
self.layers = nn.ModuleList(
[
VectorQuantize(
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
dim=codebook_dim,
codebook_dim=codebook_dim,
accept_image_fmap=accept_image_fmap,
**kwargs,
)
for _ in range(num_quantizers)
]
@ -448,7 +477,9 @@ class ResidualVQ(nn.Module):
return all_codes
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
def forward(
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
):
"""
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
@ -477,13 +508,17 @@ class ResidualVQ(nn.Module):
)
ce_losses = []
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
should_quantize_dropout = (
self.training and self.quantize_dropout and not return_loss
)
# sample a layer index at which to dropout further residual quantization
# also prepare null indices and loss
if should_quantize_dropout:
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
rand_quantize_dropout_index = randrange(
self.quantize_dropout_cutoff_index, num_quant
)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
@ -492,14 +527,23 @@ class ResidualVQ(nn.Module):
- 1
)
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
null_indices_shape = (
(x.shape[0], *x.shape[-2:])
if self.accept_image_fmap
else tuple(x.shape[:2])
)
null_indices = torch.full(
null_indices_shape, -1.0, device=device, dtype=torch.long
)
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
# go through the layers
for quantizer_index, layer in enumerate(self.layers):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
if (
should_quantize_dropout
and quantizer_index > rand_quantize_dropout_index
):
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
@ -539,7 +583,9 @@ class ResidualVQ(nn.Module):
# stack all losses and indices
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
all_losses, all_indices = map(
partial(torch.stack, dim=-1), (all_losses, all_indices)
)
ret = (quantized_out, all_indices, all_losses)
@ -599,8 +645,12 @@ class VectorQuantize(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.eps = eps
self.commitment_weight = commitment_weight
@ -614,10 +664,14 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
assert not (
ema_update and learnable_codebook
), "learnable codebook not compatible with EMA update"
assert 0 <= sync_update_v <= 1.0
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
assert not (
sync_update_v > 0.0 and not learnable_codebook
), "learnable codebook must be turned on"
self.sync_update_v = sync_update_v
@ -629,7 +683,9 @@ class VectorQuantize(nn.Module):
)
if sync_codebook is None:
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
codebook_kwargs = {
"dim": codebook_dim,
@ -794,11 +850,17 @@ class VectorQuantize(nn.Module):
# quantize again
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
quantize, embed_ind, distances = self._codebook(
x, **codebook_forward_kwargs
)
if self.training:
# determine code to use for commitment loss
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
maybe_detach = (
torch.detach
if not self.learnable_codebook or freeze_codebook
else identity
)
commit_quantize = maybe_detach(quantize)
@ -808,7 +870,9 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
quantize = quantize + self.sync_update_v * (
quantize - quantize.detach()
)
# function for calculating cross entropy loss to distance matrix
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@ -841,7 +905,9 @@ class VectorQuantize(nn.Module):
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
if self.accept_image_fmap:
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
embed_ind = rearrange(
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
)
if only_one:
embed_ind = rearrange(embed_ind, "b 1 -> b")
@ -895,8 +961,12 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2]
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
if (
self.orthogonal_reg_max_codes is not None
) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
codebook = codebook[:, rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@ -928,7 +998,9 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True
if mask is not None:
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
return quantize, embed_ind, loss
@ -1038,7 +1110,9 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num):
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
return torch.stack(
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
def pad_shape(shape, size, dim=0):
@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0:
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
samples_per_rank = sample_multinomial(
num, all_num_samples / all_num_samples.sum()
)
else:
samples_per_rank = torch.empty_like(all_num_samples)
@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module):
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = (
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
reset_cluster_size
if (reset_cluster_size is not None)
else threshold_ema_dead_code
)
assert callable(gumbel_sample)
@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module):
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
)
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.sample_fn = (
sample_vectors_distributed
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_numer)
batch_variance = variance_numer / num_vectors
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
self.update_with_decay(
"batch_variance", batch_variance, self.affine_param_batch_decay
)
def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate(
@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
sampled = self.sample_fn(
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled
@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
)
x = x.float()
@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module):
if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
embed = (embed - self.codebook_mean) * (
batch_std / codebook_std
) + self.batch_mean
dist = -cdist(flatten, embed)
@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module):
if self.training and self.ema_update and not freeze_codebook:
if self.affine_param:
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
flatten = (flatten - self.batch_mean) * (
codebook_std / batch_std
) + self.codebook_mean
if mask is not None:
embed_onehot[~mask] = 0.0
@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module):
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
quantize, embed_ind = tuple(
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
)
dist = unpack_one(dist, ps, "h * d")

View File

@ -65,7 +65,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
img.save(str(path), quality=100)
logging.info(f"Saved image: {path}")
except Exception as e:
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
logging.error(
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
)
def save_images_from_cameras(
@ -143,7 +145,9 @@ def save_images_from_cameras(
if time.perf_counter() - start_time > record_time_s:
break
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
frame_index += 1
finally:
@ -251,7 +255,9 @@ class IntelRealSenseCamera:
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
)
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
name_to_serial_dict = {
cam["name"]: cam["serial_number"] for cam in camera_infos
}
cam_sn = name_to_serial_dict[name]
return cam_sn
@ -272,13 +278,17 @@ class IntelRealSenseCamera:
if self.fps and self.width and self.height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
config.enable_stream(
rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps
)
else:
config.enable_stream(rs.stream.color)
if self.use_depth:
if self.fps and self.width and self.height:
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
config.enable_stream(
rs.stream.depth, self.width, self.height, rs.format.z16, self.fps
)
else:
config.enable_stream(rs.stream.depth)
@ -311,7 +321,9 @@ class IntelRealSenseCamera:
actual_height = color_profile.height()
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@ -331,7 +343,9 @@ class IntelRealSenseCamera:
self.is_connected = True
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
def read(
self, temporary_color: str | None = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
@ -358,11 +372,15 @@ class IntelRealSenseCamera:
color_frame = frame.get_color_frame()
if not color_frame:
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
raise OSError(
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
)
color_image = np.asanyarray(color_frame.get_data())
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
requested_color_mode = (
self.color_mode if temporary_color is None else temporary_color
)
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@ -390,7 +408,9 @@ class IntelRealSenseCamera:
if self.use_depth:
depth_frame = frame.get_depth_frame()
if not depth_frame:
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
raise OSError(
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
)
depth_map = np.asanyarray(depth_frame.get_data())
@ -432,7 +452,9 @@ class IntelRealSenseCamera:
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
if num_tries > self.fps and (
self.thread.ident is None or not self.thread.is_alive()
):
raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
)

View File

@ -31,10 +31,14 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
def find_cameras(
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
) -> list[dict]:
cameras = []
if platform.system() == "Linux":
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
print(
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
)
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
ports = _find_cameras(possible_ports, mock=mock)
for port in ports:
@ -166,7 +170,9 @@ def save_images_from_cameras(
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
if time.perf_counter() - start_time > record_time_s:
break
@ -223,7 +229,9 @@ class OpenCVCamera:
if platform.system() == "Linux":
if isinstance(self.camera_index, int):
self.port = Path(f"/dev/video{self.camera_index}")
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
elif isinstance(self.camera_index, str) and is_valid_unix_path(
self.camera_index
):
self.port = Path(self.camera_index)
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
@ -260,7 +268,9 @@ class OpenCVCamera:
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
raise RobotDeviceAlreadyConnectedError(
f"OpenCVCamera({self.camera_index}) is already connected."
)
if self.mock:
import tests.mock_cv2 as cv2
@ -271,7 +281,11 @@ class OpenCVCamera:
# when other threads are used to save the images.
cv2.setNumThreads(1)
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
camera_idx = (
f"/dev/video{self.camera_index}"
if platform.system() == "Linux"
else self.camera_index
)
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(camera_idx)
@ -311,16 +325,22 @@ class OpenCVCamera:
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
)
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
if self.width is not None and not math.isclose(
self.width, actual_width, rel_tol=1e-3
):
raise OSError(
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
)
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
if self.height is not None and not math.isclose(
self.height, actual_height, rel_tol=1e-3
):
raise OSError(
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
)
@ -350,7 +370,9 @@ class OpenCVCamera:
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
requested_color_mode = (
self.color_mode if temporary_color_mode is None else temporary_color_mode
)
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(

View File

@ -25,7 +25,9 @@ from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
def log_control_info(
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@ -92,7 +94,9 @@ def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and use_amp
else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
@ -148,7 +152,9 @@ def init_keyboard_listener(assign_rewards=False):
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
@ -273,7 +279,9 @@ def control_loop(
frame = {**observation, **action}
if "next.reward" in events:
frame["next.reward"] = events["next.reward"]
frame["next.done"] = (events["next.reward"] == 1) or (events["exit_early"])
frame["next.done"] = (events["next.reward"] == 1) or (
events["exit_early"]
)
dataset.add_frame(frame)
# if frame["next.done"]:
@ -282,7 +290,9 @@ def control_loop(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
if fps is not None:
@ -360,7 +370,11 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
@ -374,11 +388,14 @@ def sanity_check_dataset_robot_compatibility(
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
diff = DeepDiff(
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
)
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
"Dataset metadata compatibility check failed with mismatches:\n"
+ "\n".join(mismatches)
)

View File

@ -144,7 +144,9 @@ NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@ -370,7 +372,9 @@ class DynamixelMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError:
continue
@ -386,7 +390,9 @@ class DynamixelMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@ -407,7 +413,9 @@ class DynamixelMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@ -420,7 +428,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@ -495,7 +505,9 @@ class DynamixelMotorsBus:
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@ -537,15 +549,23 @@ class DynamixelMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-(resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
(resolution // 2) - values[i] - homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
@ -553,7 +573,9 @@ class DynamixelMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@ -569,19 +591,27 @@ class DynamixelMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@ -591,7 +621,9 @@ class DynamixelMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@ -630,7 +662,9 @@ class DynamixelMotorsBus:
values = np.round(values).astype(np.int32)
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
else:
@ -732,7 +766,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@ -741,7 +777,9 @@ class DynamixelMotorsBus:
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
else:
@ -770,7 +808,12 @@ class DynamixelMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@ -831,7 +874,9 @@ class DynamixelMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@ -123,7 +123,9 @@ NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@ -351,7 +353,9 @@ class FeetechMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError:
continue
@ -367,7 +371,9 @@ class FeetechMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@ -388,7 +394,9 @@ class FeetechMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@ -401,7 +409,9 @@ class FeetechMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@ -475,7 +485,9 @@ class FeetechMotorsBus:
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@ -514,18 +526,26 @@ class FeetechMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution
upp_factor = (
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@ -534,7 +554,9 @@ class FeetechMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@ -550,19 +572,27 @@ class FeetechMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@ -572,7 +602,9 @@ class FeetechMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@ -648,7 +680,9 @@ class FeetechMotorsBus:
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock:
import tests.mock_scservo_sdk as scs
else:
@ -757,7 +791,9 @@ class FeetechMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@ -766,7 +802,9 @@ class FeetechMotorsBus:
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock:
import tests.mock_scservo_sdk as scs
else:
@ -795,7 +833,12 @@ class FeetechMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@ -856,7 +899,9 @@ class FeetechMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@ -10,9 +10,7 @@ from lerobot.common.robot_devices.motors.dynamixel import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@ -23,7 +21,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode):
@ -64,12 +64,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@ -90,10 +94,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@ -102,11 +111,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
rotated_nearest_pos = compute_nearest_rounded_position(
rotated_drived_pos, arm.motor_models
)
homing_offset = rotated_target_pos - rotated_nearest_pos
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...")
print()

View File

@ -12,9 +12,7 @@ from lerobot.common.robot_devices.motors.feetech import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@ -25,7 +23,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode):
@ -126,7 +126,9 @@ def apply_offset(calib, offset):
return calib
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
if robot_type == "so100":
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
elif robot_type == "moss":
@ -135,18 +137,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm
raise ValueError(robot_type)
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration_so100(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "so100" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
raise NotImplementedError(
"Auto calibration only supports the follower of so100 arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@ -193,11 +204,16 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate(
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
arm,
"elbow_flex",
positive_first=False,
in_between_move_hook=in_between_move_hook,
)
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
)
time.sleep(1)
def in_between_move_hook():
@ -225,18 +241,30 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
}
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
arm.write(
"Goal_Position",
round(calib["shoulder_lift"]["zero_pos"] - 1600),
"shoulder_lift",
)
time.sleep(2)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
arm.write(
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
)
time.sleep(2)
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
arm.write(
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
)
time.sleep(2)
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
time.sleep(2)
print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate(
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
arm,
"wrist_roll",
invert_drive_mode=True,
positive_first=False,
while_move_hook=while_move_hook,
)
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
@ -246,7 +274,9 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
)
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
@ -275,18 +305,27 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
return calib_dict
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration_moss(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
raise NotImplementedError(
"Auto calibration only supports the follower of moss arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@ -370,8 +409,12 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
)
time.sleep(2)
calib_modes = []
@ -398,7 +441,9 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
return calib_dict
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_manual_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance before calibration, setting a same goal position
for each motor of two different robots will get two very different positions. But after calibration,
@ -421,12 +466,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@ -446,10 +495,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@ -461,7 +515,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...")
print()

View File

@ -17,11 +17,16 @@ from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
def ensure_safe_goal_position(
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
goal_pos: torch.Tensor,
present_pos: torch.Tensor,
max_relative_target: float | list[float],
):
# Cap relative action target magnitude for safety.
diff = goal_pos - present_pos
@ -263,7 +268,9 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
# Check both arms can be read
for name in self.follower_arms:
@ -295,18 +302,26 @@ class ManipulatorRobot:
print(f"Missing calibration file '{arm_calib_path}'")
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
run_arm_calibration,
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_calibration(
arm, self.robot_type, name, arm_type
)
elif self.robot_type in ["so100", "moss"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
print(
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
@ -325,13 +340,17 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
raise ValueError(
"To run set robot preset, the torque must be disabled on all motors."
)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
all_motors_except_gripper = [
name for name in arm.motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper)
@ -360,7 +379,9 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
def set_aloha_robot_preset(self):
def set_shadow_(arm):
@ -390,11 +411,15 @@ class ManipulatorRobot:
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name for name in self.follower_arms[name].motor_names if name != "gripper"
name
for name in self.follower_arms[name].motor_names
if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
self.follower_arms[name].write(
"Operating_Mode", 4, all_motors_except_gripper
)
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even tho,
@ -442,7 +467,9 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
self.logs[f"read_leader_{name}_pos_dt_s"] = (
time.perf_counter() - before_lread_t
)
# Send goal position to the follower
follower_goal_pos = {}
@ -463,14 +490,18 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Used when record_data=True
follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.int32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
time.perf_counter() - before_fwrite_t
)
# Early exit when recording data is not requested
if not record_data:
@ -483,7 +514,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position
state = []
@ -505,8 +538,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionnaries
obs_dict, action_dict = {}, {}
@ -530,7 +567,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position
state = []
@ -545,8 +584,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionnaries and format to pytorch
obs_dict = {}
@ -592,7 +635,9 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Save tensor to concat and return
action_sent.append(goal_pos)

View File

@ -52,7 +52,9 @@ class StretchRobot(StretchAPI):
def connect(self) -> None:
self.is_connected = self.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
print(
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
raise ConnectionError()
for name in self.cameras:
@ -60,7 +62,9 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
print(
"Could not connect to the cameras, check that all cameras are plugged-in."
)
raise ConnectionError()
self.run_calibration()
@ -105,8 +109,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionnaries
obs_dict, action_dict = {}, {}
@ -150,8 +158,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionnaries
obs_dict = {}

View File

@ -34,7 +34,8 @@ class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected."""
def __init__(
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
self,
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
):
self.message = message
super().__init__(self.message)

View File

@ -17,7 +17,9 @@ import importlib
import logging
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
def is_package_available(
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.

View File

@ -28,7 +28,9 @@ def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
"ignore",
"pkg_resources is deprecated as an API",
category=DeprecationWarning,
)
imageio.mimsave(video_path, stacked_frames, fps=fps)

View File

@ -143,7 +143,10 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
"/".join(
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
)
@ -154,10 +157,26 @@ def print_cuda_memory_usage():
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
print(
"Current GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.memory_allocated(0) / 1024**2
)
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
def capture_timestamp_utc():
@ -206,7 +225,12 @@ def has_method(cls: object, method_name: str):
class TimerManager:
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
def __init__(
self,
elapsed_time_list: list[float] | None = None,
label="Elapsed time",
log=True,
):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log

View File

@ -9,7 +9,7 @@ env:
action_dim: 6
fps: ${fps}
device: mps
wrapper:
crop_params_dict:
observation.images.front: [102, 43, 358, 523]
@ -28,4 +28,4 @@ env:
reward_classifier:
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml

View File

@ -66,7 +66,7 @@ policy:
observation.image: [3, 64, 64]
output_shapes:
action: [7]
camera_number: 1
# Normalization / Unnormalization
@ -79,7 +79,7 @@ policy:
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,

View File

@ -95,20 +95,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
break
if motor_index == -1:
raise ValueError("No motors detected. Please ensure you have one motor connected.")
raise ValueError(
"No motors detected. Please ensure you have one motor connected."
)
print(f"Motor index found at: {motor_index}")
if brand == "feetech":
# Allows ID and BAUDRATE to be written in memory
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
if baudrate != baudrate_des:
print(f"Setting its baudrate to {baudrate_des}")
baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)
# The write can fail, so we allow retries
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
)
time.sleep(0.5)
motor_bus.set_bus_baudrate(baudrate_des)
present_baudrate_idx = motor_bus.read_with_motor_ids(
@ -123,7 +129,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
present_idx = motor_bus.read_with_motor_ids(
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
)
if present_idx != motor_idx_des:
raise OSError("Failed to write index.")
@ -151,12 +159,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
parser.add_argument(
"--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
"--port",
type=str,
required=True,
help="Motors bus port (e.g. dynamixel,feetech)",
)
parser.add_argument(
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
)
parser.add_argument(
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
)
parser.add_argument(
"--ID",
type=int,
required=True,
help="Desired ID of the current motor (e.g. 1,2,3)",
)
parser.add_argument(
"--baudrate",
type=int,
default=1000000,
help="Desired baudrate for the motor (default: 1000000)",
)
args = parser.parse_args()

View File

@ -136,7 +136,11 @@ def init_sim_calibration(robot, cfg):
axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
return {
"start_pos": start_pos,
"axis_directions": axis_directions,
"offsets": offsets,
}
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
@ -157,7 +161,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0))
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
if (
teleop_time_s is not None
and time.perf_counter() - start_teleop_t > teleop_time_s
):
print("Teleoperation processes finished.")
break
@ -189,19 +196,27 @@ def record(
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
if assign_rewards
else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
policy, policy_fps, device, use_amp = init_policy(
pretrained_policy_name_or_path, policy_overrides
)
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
logging.warning(
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
if policy is None and process_action_from_leader is None:
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
raise ValueError(
"Either policy or process_action_fn has to be set to enable control in sim."
)
# initialize listener before sim env
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
@ -243,7 +258,11 @@ def record(
"shape": env.observation_space[obs_key].shape,
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features["action"] = {
"dtype": "float32",
"shape": env.action_space.shape,
"names": None,
}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes
@ -344,7 +363,9 @@ def record(
if events["stop_recording"] or recorded_episodes >= num_episodes:
break
else:
logging.info("Waiting for a few seconds before starting next episode recording...")
logging.info(
"Waiting for a few seconds before starting next episode recording..."
)
busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True)
@ -362,7 +383,12 @@ def record(
def replay(
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
env,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
local_files_only: bool = True,
):
env = env()
@ -409,7 +435,10 @@ if __name__ == "__main__":
parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_record.add_argument(
"--root",
@ -435,7 +464,9 @@ if __name__ == "__main__":
required=True,
help="A description of the task preformed during recording that can be used as a language instruction.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument(
"--run-compute-stats",
type=int,
@ -496,7 +527,10 @@ if __name__ == "__main__":
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_replay.add_argument(
"--root",
@ -510,7 +544,9 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
parser_replay.add_argument(
"--episode", type=int, default=0, help="Index of the episodes to replay."
)
args = parser.parse_args()

View File

@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
cuda_version = (
torch._C._cuda_getCompiledVersion()
if HAS_TORCH and torch.version.cuda is not None
else "N/A"
)
# TODO(aliberts): refactor into an actual command `lerobot env`
@ -77,7 +81,9 @@ def display_sys_info() -> dict:
"Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>",
}
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
print(
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
)
print(format_dict(info))
return info

View File

@ -170,7 +170,10 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
successes = [
info["is_success"] if info is not None else False
for info in info["final_info"]
]
else:
successes = [False] * env.num_envs
@ -184,9 +187,13 @@ def rollout(
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
.numpy()
.mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
# Track the final observation.
@ -204,7 +211,9 @@ def rollout(
if return_observations:
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
stacked_observations[key] = torch.stack(
[obs[key] for obs in all_observations], dim=1
)
ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"):
@ -266,7 +275,9 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
ep_frames.append(
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@ -278,7 +289,9 @@ def eval_policy(
episode_data: dict | None = None
# we dont want progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
progbar = trange(
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@ -289,7 +302,8 @@ def eval_policy(
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
start_seed + (batch_ix * env.num_envs),
start_seed + ((batch_ix + 1) * env.num_envs),
)
rollout_data = rollout(
env,
@ -307,13 +321,22 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
mask = (
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
# Extend metrics.
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
batch_sum_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
batch_max_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "max"
)
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
batch_successes = einops.reduce(
(rollout_data["success"] * mask), "b n -> b", "any"
)
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
@ -326,17 +349,27 @@ def eval_policy(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
start_data_index=(
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert (
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
episode_data = {
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
@ -354,7 +387,9 @@ def eval_policy(
target=write_video,
args=(
str(video_path),
stacked_frames[: done_index + 1], # + 1 to capture the last observation
stacked_frames[
: done_index + 1
], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"],
),
)
@ -363,7 +398,9 @@ def eval_policy(
n_episodes_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
{
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
)
# Wait till all video rendering threads are done.
@ -409,7 +446,11 @@ def eval_policy(
def _compile_episode_data(
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
rollout_data: dict,
done_indices: Tensor,
start_episode_index: int,
start_data_index: int,
fps: float,
) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)`
@ -427,12 +468,16 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"episode_index": torch.tensor(
[start_episode_index + ep_ix] * (num_frames - 1)
),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
torch.float32
),
}
# For the last observation frame, all other keys will just be copy padded.
@ -448,7 +493,9 @@ def _compile_episode_data(
for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
data_dict["index"] = torch.arange(
start_data_index, start_data_index + total_frames, 1
)
return data_dict

View File

@ -46,7 +46,11 @@ import torch
from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
@ -151,11 +161,17 @@ def rollout(
images = []
for key in image_keys:
if display_cameras:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
images.append(observation[key].to("mps"))
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
reward = (
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
all_rewards.append(reward)
# print("REWARD : ", reward)
@ -219,11 +235,19 @@ def eval_policy(
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
reward_classifier = get_classifier(
reward_classifier_pretrained_path, reward_classifier_config_file
)
for _ in progbar:
rollout_data = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
robot,
policy,
reward_classifier,
fps,
control_time_s,
use_amp,
display_cameras,
)
rollouts.append(rollout_data)
@ -289,7 +313,9 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.space:
@ -301,7 +327,10 @@ def init_keyboard_listener():
"Place the leader in similar pose to the follower and press space again."
)
events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
else:
events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
@ -351,7 +380,9 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument(
"--out-dir",
help=(
@ -360,7 +391,8 @@ if __name__ == "__main__":
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
"--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
)
parser.add_argument(
"--reward-classifier-pretrained-path",

View File

@ -32,9 +32,13 @@ def find_port():
print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the USB cable.")
elif len(ports_diff) == 0:
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. No difference was found ({ports_diff})."
)
else:
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. More than one port was found ({ports_diff})."
)
if __name__ == "__main__":

View File

@ -56,24 +56,42 @@ from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
from lerobot.common.datasets.utils import (
create_branch,
create_lerobot_dataset_card,
flatten_dict,
)
def get_from_raw_to_lerobot_format_fn(raw_format: str):
if raw_format == "pusht_zarr":
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import (
from_raw_to_lerobot_format,
)
elif raw_format == "umi_zarr":
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import (
from_raw_to_lerobot_format,
)
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import (
from_raw_to_lerobot_format,
)
elif raw_format in ["rlds", "openx"]:
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import (
from_raw_to_lerobot_format,
)
elif raw_format == "dora_parquet":
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import (
from_raw_to_lerobot_format,
)
elif raw_format == "xarm_pkl":
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import (
from_raw_to_lerobot_format,
)
elif raw_format == "cam_png":
from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import (
from_raw_to_lerobot_format,
)
else:
raise ValueError(
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
@ -83,7 +101,10 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
def save_meta_data(
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
info: dict[str, Any],
stats: dict,
episode_data_index: dict[str, list],
meta_data_dir: Path,
):
meta_data_dir.mkdir(parents=True, exist_ok=True)
@ -97,12 +118,16 @@ def save_meta_data(
save_file(flatten_dict(stats), stats_path)
# save episode_data_index
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
episode_data_index = {
key: torch.tensor(episode_data_index[key]) for key in episode_data_index
}
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
save_file(episode_data_index, ep_data_idx_path)
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
def push_meta_data_to_hub(
repo_id: str, meta_data_dir: str | Path, revision: str | None
):
"""Expect all meta data files to be all stored in a single "meta_data" directory.
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
"""
@ -187,7 +212,9 @@ def push_dataset_to_hub(
if force_override:
shutil.rmtree(local_dir)
elif not resume:
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
raise ValueError(
f"`local_dir` already exists ({local_dir}). Use `--force-override 1`."
)
meta_data_dir = local_dir / "meta_data"
videos_dir = local_dir / "videos"
@ -223,7 +250,9 @@ def push_dataset_to_hub(
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
if local_dir:
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset = hf_dataset.with_format(
None
) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
if push_to_hub or local_dir:

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import io
@ -737,7 +736,6 @@ def concatenate_batch_transitions(
if __name__ == "__main__":
import numpy as np
from tempfile import TemporaryDirectory
# ===== Test 1: Create and use a synthetic ReplayBuffer =====
@ -1139,7 +1137,7 @@ if __name__ == "__main__":
savings_percent = (std_mem - opt_mem) / std_mem * 100
print(f"\nMemory optimization result:")
print("\nMemory optimization result:")
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")

View File

@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser = argparse.ArgumentParser(
description="Crop rectangular ROIs from a LeRobot dataset."
)
parser.add_argument(
"--repo-id",
type=str,
@ -247,7 +249,9 @@ if __name__ == "__main__":
args = parser.parse_args()
local_files_only = args.root is not None
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
dataset = LeRobotDataset(
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
@ -256,7 +260,7 @@ if __name__ == "__main__":
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path, "r") as f:
with open(args.crop_params_path) as f:
rois = json.load(f)
# rois = {

View File

@ -31,7 +31,9 @@ def find_joint_bounds(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t
@ -57,7 +59,12 @@ if __name__ == "__main__":
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)

View File

@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device:str
cfg: DictConfig, logger: Logger, device: str, storage_device: str
) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(

View File

@ -10,7 +10,9 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
def action(self, action):
action, telop = action
@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
def step(self, action):
if isinstance(action, tuple):
@ -137,7 +145,9 @@ def make_maniskill(
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
@ -149,10 +159,11 @@ def make_maniskill(
if __name__ == "__main__":
import argparse
import hydra
from omegaconf import OmegaConf
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
args = parser.parse_args()
# Initialize config

View File

@ -73,7 +73,9 @@ def make_optimizer_and_scheduler(cfg, policy):
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
optimizer_params_dicts,
lr=cfg.training.lr,
weight_decay=cfg.training.weight_decay,
)
lr_scheduler = None
elif cfg.policy.name == "diffusion":
@ -100,14 +102,23 @@ def make_optimizer_and_scheduler(cfg, policy):
optimizer = torch.optim.Adam(
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
{
"params": policy.critic_ensemble.parameters(),
"lr": policy.config.critic_lr,
},
{
"params": policy.temperature.parameters(),
"lr": policy.config.temperature_lr,
},
]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
from lerobot.common.policies.vqbet.modeling_vqbet import (
VQBeTOptimizer,
VQBeTScheduler,
)
optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg)
@ -215,7 +226,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

View File

@ -14,7 +14,6 @@
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat
import hydra
@ -28,14 +27,16 @@ from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
format_big_number,
@ -50,7 +51,11 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
return model
@ -62,7 +67,9 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@ -71,7 +78,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
def train_epoch(
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
@ -85,7 +94,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
outputs = model(images)
loss = criterion(outputs.logits, labels)
@ -130,7 +143,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@ -143,7 +158,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
):
outputs = model(images)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
else:
outputs = model(images)
@ -161,16 +178,24 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
f"image_{img_key}": wandb.Image(
images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
@ -238,15 +263,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
elif device.type == "mps":
torch.mps.synchronize()
with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
with (
profiler.profile(record_shapes=True) as prof,
profiler.record_function("model_inference"),
):
_ = model(x)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
inference_times = np.array(inference_times)
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
avg, median, std = (
inference_times.mean(),
np.median(inference_times),
inference_times.std(),
)
print(
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
)
@ -264,7 +298,11 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
@hydra.main(
version_base="1.2",
config_path="../configs/policy",
config_name="hilserl_classifier",
)
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))
@ -278,7 +316,9 @@ def train(cfg: DictConfig) -> None:
# Setup dataset and dataloaders
dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
cfg.dataset_repo_id,
root=cfg.dataset_root,
local_files_only=cfg.local_files_only,
)
logging.info(f"Dataset size: {len(dataset)}")
@ -314,7 +354,9 @@ def train(cfg: DictConfig) -> None:
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
@ -327,7 +369,9 @@ def train(cfg: DictConfig) -> None:
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
@ -346,7 +390,11 @@ def train(cfg: DictConfig) -> None:
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters
@ -362,7 +410,17 @@ def train(cfg: DictConfig) -> None:
for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Periodic validation
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:

View File

@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict
import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm
@ -30,20 +29,17 @@ from tqdm import tqdm
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy):
@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy):
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@ -198,8 +198,12 @@ class ReplayBuffer:
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
replay_buffer = cls(
capacity=len(lerobot_dataset), device=device, state_keys=state_keys
)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
replay_buffer.add(
@ -244,7 +248,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = []
num_frames = len(dataset)
@ -298,36 +304,40 @@ class ReplayBuffer:
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_state[key] = torch.cat(
[t["state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
self.device
)
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_next_state[key] = torch.cat(
[t["next_state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
@ -344,7 +354,13 @@ def concatenate_batch_transitions(
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
@ -355,7 +371,11 @@ def concatenate_batch_transitions(
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
[
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
device=device,
)
assert isinstance(policy, nn.Module)
@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# TODO: Handle resume
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
batch_size = cfg.training.batch_size
@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if interaction_step >= cfg.training.online_step_before_learning:
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
next_obs, reward, done, truncated, info = online_env.step(
action.cpu().numpy()
)
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
action = torch.tensor(action, dtype=torch.float32).to(
device, non_blocking=True
)
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Because we are using a single environment
# we can safely assume that the episode is done
if done[0] or truncated[0]:
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
logging.info(
f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logger.log_dict(
{"Sum episode reward": sum_reward_episode}, interaction_step
)
sum_reward_episode = 0
# HACK: This is for maniskill
logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
)
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
logger.log_dict(
{"Episode success": info["success"].float().item()}, interaction_step
)
replay_buffer.add(
state=obs,
@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(
observations=observations
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@ -573,7 +611,9 @@ def train_cli(cfg: dict):
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()

View File

@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
assert (
c < h and c < w
), f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
)
return hwc_uint8_numpy

View File

@ -81,7 +81,11 @@ def run_server(
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app = Flask(
__name__,
static_folder=static_folder.resolve(),
template_folder=template_folder.resolve(),
)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
@ -138,8 +142,12 @@ def run_server(
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
@app.route(
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
)
def show_episode(
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
@ -150,7 +158,9 @@ def run_server(
400,
)
dataset_version = (
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
dataset.meta._version
if isinstance(dataset, LeRobotDataset)
else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@ -171,15 +181,21 @@ def run_server(
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
dataset.meta.get_video_file_path(episode_id, key)
for key in dataset.meta.video_keys
]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
{
"url": url_for("static", filename=video_path),
"filename": video_path.parent.name,
}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
video_keys = [
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@ -198,16 +214,24 @@ def run_server(
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
tasks_jsonl = [
json.loads(line) for line in response.text.splitlines() if line.strip()
]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
filtered_tasks_jsonl = [
row for row in tasks_jsonl if row["episode_index"] == episode_id
]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
range(
dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes
)
)
return render_template(
@ -233,7 +257,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
selected_columns = [
col for col, ft in dataset.features.items() if ft["dtype"] == "float32"
]
selected_columns.remove("timestamp")
# init header of csv with state and action names
@ -247,7 +273,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
)
header += [f"{column_name}_{i}" for i in range(dim_state)]
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
if (
"names" in dataset.features[column_name]
and dataset.features[column_name]["names"]
):
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
@ -268,8 +297,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
url = (
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
@ -302,7 +335,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
def get_episode_language_instruction(
dataset: LeRobotDataset, ep_index: int
) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
@ -313,11 +348,15 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
"', shape=(), dtype=string)"
)
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json"
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
@ -346,7 +385,9 @@ def visualize_dataset_html(
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
logging.info(
f"Output directory already exists. Loading from it: '{output_dir}'"
)
output_dir.mkdir(parents=True, exist_ok=True)

View File

@ -126,7 +126,12 @@ def patch_builtins_input(monkeypatch):
def pytest_addoption(parser):
parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
parser.addoption(
"--seed",
action="store",
default="42",
help="Set random seed for reproducibility",
)
@pytest.fixture(autouse=True)

View File

@ -7,17 +7,39 @@ DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"laptop": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
"phone": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {

View File

@ -8,7 +8,11 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
CODEBASE_VERSION,
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
@ -35,7 +39,9 @@ def get_task_index(task_dicts: dict, task: str) -> int:
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
def _create_img_tensor(
height=100, width=100, channels=3, dtype=torch.float32
) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@ -43,10 +49,14 @@ def img_tensor_factory():
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
def _create_img_array(
height=100, width=100, channels=3, dtype=np.uint8
) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
img_array = np.random.randint(
0, 256, size=(height, width, channels), dtype=dtype
)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
@ -75,10 +85,13 @@ def features_factory():
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
camera_ft = {
key: {"dtype": "image", **ft} for key, ft in camera_features.items()
}
return {
**motor_features,
**camera_ft,
@ -177,7 +190,9 @@ def episodes_factory(tasks_factory):
if total_episodes <= 0 or total_frames <= 0:
raise ValueError("num_episodes and total_length must be positive integers.")
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
raise ValueError(
"total_length must be greater than or equal to num_episodes."
)
if not tasks:
min_tasks = 2 if multi_task else 1
@ -185,10 +200,14 @@ def episodes_factory(tasks_factory):
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
raise ValueError(
"The number of tasks should be less than the number of episodes."
)
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
lengths = np.random.multinomial(
total_frames, [1 / total_episodes] * total_episodes
).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks]
num_tasks_available = len(tasks_list)
@ -196,9 +215,13 @@ def episodes_factory(tasks_factory):
episodes_list = []
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
num_tasks_in_episode = (
random.randint(1, min(3, num_tasks_available)) if multi_task else 1
)
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
episode_tasks = random.sample(
tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))
)
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
@ -217,7 +240,9 @@ def episodes_factory(tasks_factory):
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def hf_dataset_factory(
features_factory, tasks_factory, episodes_factory, img_array_factory
):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
@ -236,13 +261,22 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
timestamp_col = np.concatenate(
(timestamp_col, np.arange(ep_dict["length"]) / fps)
)
frame_index_col = np.concatenate(
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
)
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
(
episode_index_col,
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
)
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
task_index = np.concatenate(
(task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))
)
index_col = np.arange(len(episode_index_col))
@ -254,7 +288,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
robot_cols[key] = np.random.random(
(len(index_col), ft["shape"][0])
).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
@ -299,7 +335,9 @@ def lerobot_dataset_metadata_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
@ -316,10 +354,14 @@ def lerobot_dataset_metadata_factory(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_get_hub_safe_version_patch.side_effect = (
lambda repo_id, version: version
)
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return LeRobotDatasetMetadata(
repo_id=repo_id, root=root, local_files_only=local_files_only
)
return _create_lerobot_dataset_metadata
@ -350,7 +392,9 @@ def lerobot_dataset_factory(
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
)
if not stats:
stats = stats_factory(features=info["features"])
@ -364,7 +408,9 @@ def lerobot_dataset_factory(
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
hf_dataset = hf_dataset_factory(
tasks=tasks, episodes=episode_dicts, fps=info["fps"]
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
@ -383,7 +429,9 @@ def lerobot_dataset_factory(
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
) as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,

View File

@ -7,7 +7,12 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from lerobot.common.datasets.utils import (
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
@pytest.fixture(scope="session")
@ -69,7 +74,10 @@ def episode_path(episodes_factory):
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
dir: Path,
ep_idx: int = 0,
hf_dataset: datasets.Dataset | None = None,
info: dict | None = None,
) -> Path:
if not info:
info = info_factory()

31
tests/fixtures/hub.py vendored
View File

@ -4,7 +4,12 @@ import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from lerobot.common.datasets.utils import (
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@ -41,15 +46,21 @@ def mock_snapshot_download_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
hf_dataset = hf_dataset_factory(
tasks=tasks, episodes=episodes, fps=info["fps"]
)
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
episode_index = int(
path.stem[len("episode_") :]
) # 'episode_000000' -> 0
return episode_index
else:
return None
@ -74,12 +85,16 @@ def mock_snapshot_download_factory(
for episode_dict in episodes:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_path = info["data_path"].format(
episode_chunk=ep_chunk, episode_index=ep_idx
)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
all_files,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
# Create allowed files
@ -87,7 +102,9 @@ def mock_snapshot_download_factory(
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
_ = single_episode_parquet_path(
local_dir, episode_index, hf_dataset, info
)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:

View File

@ -67,7 +67,9 @@ class GroupSyncRead:
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
self.packet_handler.data[motor_index] = get_default_motor_values(
motor_index
)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS

View File

@ -17,7 +17,9 @@ class config: # noqa: N801
def enable_device(self, device_id: str):
self.device_enabled = device_id
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
def enable_stream(
self, stream_type: stream, width=None, height=None, color_format=None, fps=None
):
self.stream_type = stream_type
# Overwrite default values when possible
self.width = 848 if width is None else width

View File

@ -78,7 +78,9 @@ class GroupSyncRead:
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
self.packet_handler.data[motor_index] = get_default_motor_values(
motor_index
)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS

View File

@ -25,7 +25,10 @@ from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig,
)
BATCH_SIZE = 1000
LR = 0.1
@ -43,7 +46,9 @@ def train_evaluate_multiclass_classifier():
logging.info(
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
)
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
multiclass_config = ClassifierConfig(
model_name="microsoft/resnet-18", device=DEVICE, num_classes=10
)
multiclass_classifier = Classifier(multiclass_config)
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
@ -114,10 +119,18 @@ def train_evaluate_multiclass_classifier():
test_probs = torch.stack(test_probs)
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
precision = Precision(
task="multiclass", average="weighted", num_classes=multiclass_num_classes
)
recall = Recall(
task="multiclass", average="weighted", num_classes=multiclass_num_classes
)
f1 = F1Score(
task="multiclass", average="weighted", num_classes=multiclass_num_classes
)
auroc = AUROC(
task="multiclass", num_classes=multiclass_num_classes, average="weighted"
)
# Calculate metrics
acc = accuracy(test_predictions, test_labels)
@ -146,18 +159,28 @@ def train_evaluate_binary_classifier():
new_label = float(1.0) if label == target_class else float(0.0)
new_targets.append(new_label)
dataset.targets = new_targets # Replace the original labels with the binary ones
dataset.targets = (
new_targets # Replace the original labels with the binary ones
)
return dataset
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
binary_train_dataset = CIFAR10(
root="data", train=True, download=True, transform=ToTensor()
)
binary_test_dataset = CIFAR10(
root="data", train=False, download=True, transform=ToTensor()
)
# Apply one-vs-rest labeling
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
binary_trainloader = DataLoader(
binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
binary_testloader = DataLoader(
binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False
)
binary_epoch = 1

View File

@ -9,7 +9,9 @@ from tests.utils import require_package
def test_classifier_output():
output = ClassifierOutput(
logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
logits=torch.tensor([1, 2, 3]),
probabilities=torch.tensor([0.1, 0.2, 0.3]),
hidden_states=None,
)
assert (
@ -20,7 +22,9 @@ def test_classifier_output():
@require_package("transformers")
def test_binary_classifier_with_default_params():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
config = ClassifierConfig()
classifier = Classifier(config)
@ -41,7 +45,9 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
def test_multiclass_classifier():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
num_classes = 5
config = ClassifierConfig(num_classes=num_classes)
@ -63,7 +69,9 @@ def test_multiclass_classifier():
@require_package("transformers")
def test_default_device():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
config = ClassifierConfig()
assert config.device == "cpu"
@ -75,7 +83,9 @@ def test_default_device():
@require_package("transformers")
def test_explicit_device_setup():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
config = ClassifierConfig(device="meta")
assert config.device == "meta"

View File

@ -52,7 +52,13 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
@ -88,4 +94,6 @@ if __name__ == "__main__":
"lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch",
]:
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
save_dataset_to_safetensors(
"tests/data/save_dataset_to_safetensors", repo_id=dataset
)

View File

@ -70,7 +70,9 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
param_stats[f"{key}_std"] = (
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
)
optimizer.zero_grad()
policy.reset()
@ -97,7 +99,9 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
else:
actions_queue = train_cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
actions = {
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
}
return output_dict, grad_stats, param_stats, actions

View File

@ -24,7 +24,10 @@ pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]'
import numpy as np
import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
# Maximum absolute difference between two consecutive images recored by a camera.
@ -99,7 +102,11 @@ def test_camera(request, camera_type, mock):
)
# TODO(rcadene): properly set `rtol`
np.testing.assert_allclose(
color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
color_image,
async_color_image,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
)
# Test disconnecting
@ -118,7 +125,11 @@ def test_camera(request, camera_type, mock):
assert camera.color_mode == "bgr"
bgr_color_image = camera.read()
np.testing.assert_allclose(
color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
color_image,
bgr_color_image[:, :, [2, 1, 0]],
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
)
del camera
@ -153,7 +164,11 @@ def test_camera(request, camera_type, mock):
rot_color_image = camera.read()
np.testing.assert_allclose(
rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
rot_color_image,
manual_rot_img,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
)
del camera
@ -187,7 +202,9 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
from lerobot.common.robot_devices.cameras.intelrealsense import (
save_images_from_cameras,
)
# Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock)

View File

@ -335,8 +335,12 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
)
dataset = record(robot, rec_cfg)
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events[
"rerecord_episode"
], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events[
"exit_early"
], "`exit_early` wasn't properly reset to False"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -392,7 +396,8 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize(
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
"robot_type, mock, num_image_writer_processes",
[("koch", True, 0), ("koch", True, 1)],
)
@require_robot
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):

View File

@ -61,7 +61,9 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
@ -226,7 +228,9 @@ def test_compute_stats_on_xarm():
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
computed_stats = compute_stats(
dataset, batch_size=int(len(dataset) * 0.25), num_workers=0
)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
@ -247,7 +251,9 @@ def test_compute_stats_on_xarm():
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
einops.reduce(
(full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"
)
)
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
@ -292,7 +298,9 @@ def test_flatten_unflatten_dict():
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
assert json.dumps(original_d, sort_keys=True) == json.dumps(
d, sort_keys=True
), f"{original_d} != {d}"
@pytest.mark.parametrize(
@ -346,7 +354,13 @@ def test_backward_compatibility(repo_id):
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
load_and_compare(i)
load_and_compare(i + 1)
@ -383,23 +397,40 @@ def test_multidataset_aggregate_stats():
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
{
"a": data_a[:10],
"b": data_b[:10],
"c": data_c[:10],
"index": torch.arange(10),
}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2 = Dataset.from_dict(
{"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}
)
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3 = Dataset.from_dict(
{"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}
)
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_1.stats = compute_stats(
dataset_1, batch_size=len(hf_dataset_1), num_workers=0
)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_2.stats = compute_stats(
dataset_2, batch_size=len(hf_dataset_2), num_workers=0
)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
dataset_3.stats = compute_stats(
dataset_3, batch_size=len(hf_dataset_3), num_workers=0
)
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(
stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)
)
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))

View File

@ -22,13 +22,17 @@ def synced_hf_dataset_factory(hf_dataset_factory):
@pytest.fixture(scope="module")
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
def _create_unsynced_hf_dataset(
fps: int = 30, tolerance_s: float = 1e-4
) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just outside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
df.at[30, "timestamp"] = dtype.type(
df.at[30, "timestamp"] + (tolerance_s * 1.1)
)
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
@ -38,13 +42,17 @@ def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
@pytest.fixture(scope="module")
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
def _create_slightly_off_hf_dataset(
fps: int = 30, tolerance_s: float = 1e-4
) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just inside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
df.at[30, "timestamp"] = dtype.type(
df.at[30, "timestamp"] + (tolerance_s * 0.9)
)
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
@ -158,7 +166,9 @@ def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
def test_check_timestamps_sync_single_timestamp():
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
single_timestamp_hf_dataset = Dataset.from_dict(
{"timestamp": [0.0], "episode_index": [0]}
)
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
fps = 30
@ -207,7 +217,9 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(
fps, tolerance_s
)
result = check_delta_timestamps(
delta_timestamps=slightly_off_delta_timestamps,
fps=fps,

View File

@ -364,10 +364,16 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
for transform in transforms:
transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
assert any(
transform_dir.iterdir()
), f"No transformed images found in {transform} directory."
# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
"min.png",
"max.png",
"mean.png",
]
for file_name in expected_files:
assert (transform_dir / file_name).exists(), (
f"{file_name} was not found in {transform} directory."

View File

@ -160,7 +160,9 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
np.uint8
)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
@ -175,7 +177,9 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
np.uint8
)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
@ -265,7 +269,9 @@ def test_wait_until_done(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
num_images = 100
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
image_arrays = [
img_array_factory(height=500, width=500) for _ in range(num_images)
]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)

View File

@ -30,7 +30,10 @@ import time
import numpy as np
import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.scripts.find_motors_bus_port import find_port
from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor
@ -63,7 +66,9 @@ def test_configure_motors_all_ids_1(request, motor_type, mock):
else:
raise ValueError(motor_type)
input("Are you sure you want to re-configure the motors? Press enter to continue...")
input(
"Are you sure you want to re-configure the motors? Press enter to continue..."
)
# This test expect the configuration was already correct.
motors_bus = make_motors_bus(motor_type, mock=mock)
motors_bus.connect()

View File

@ -44,13 +44,23 @@ def make_new_buffer(
return buffer, write_dir
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
def make_spoof_data_frames(
n_episodes: int, n_frames_per_episode: int
) -> dict[str, np.ndarray]:
new_data = {
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
data_key: np.arange(
n_frames_per_episode * n_episodes * np.prod(data_shape)
).reshape(-1, *data_shape),
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
np.arange(n_episodes), n_frames_per_episode
),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(
np.arange(n_frames_per_episode), n_episodes
),
OnlineBuffer.TIMESTAMP_KEY: np.tile(
np.arange(n_frames_per_episode) / fps, n_episodes
),
}
return new_data
@ -166,7 +176,9 @@ def test_delta_timestamps_within_tolerance():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert torch.allclose(
data, torch.tensor([0, 2, 3])
), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
@ -219,58 +231,89 @@ def test_compute_sampler_weights_trivial(
online_dataset_size: int,
online_sampling_ratio: float,
):
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=offline_dataset_size
)
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
make_spoof_data_frames(
n_episodes=2, n_frames_per_episode=online_dataset_size // 2
)
)
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
)
if offline_dataset_size == 0 or online_dataset_size == 0:
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
elif online_sampling_ratio == 0:
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
expected_weights = torch.cat(
[torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
)
elif online_sampling_ratio == 1:
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
expected_weights = torch.cat(
[torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
)
expected_weights /= expected_weights.sum()
assert torch.allclose(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=4
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
online_sampling_ratio = 0.8
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
)
assert torch.allclose(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
lerobot_dataset_factory, tmp_path
):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=4
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=0.8,
online_drop_n_last_frames=1,
)
assert torch.allclose(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
)
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
"""Note: test copied from test_sampler."""
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=2
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
weights = compute_sampler_weights(
offline_dataset,
@ -279,4 +322,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
assert torch.allclose(
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
)

View File

@ -173,7 +173,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert set(batch) == set(
batch_
), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
for k in batch
@ -187,7 +189,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
observation = preprocess_observation(observation)
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
observation = {
key: observation[key].to(DEVICE, non_blocking=True) for key in observation
}
# get the next action for the environment (also check that the observation batch is not modified)
observation_ = deepcopy(observation)
@ -417,7 +421,8 @@ def test_backward_compatibility(
6. Remember to stage and commit the resulting changes to `tests/data`.
"""
env_policy_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
Path("tests/data/save_policy_to_safetensors")
/ f"{env_name}_{policy_name}{file_name_extra}"
)
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
@ -461,7 +466,9 @@ def test_act_temporal_ensembler():
batch_size = batch_seq.shape[0]
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
# dimension of `batch_seq`.
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(
-1
)
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
for i in range(episode_length):
@ -484,7 +491,8 @@ def test_act_temporal_ensembler():
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
offline_avg = (
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum")
/ weights[: i + 1].sum()
)
# Sanity check. The average should be between the extrema.
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)

View File

@ -31,7 +31,11 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
"data/action",
shape=(num_frames, 1),
chunks=(num_frames, 1),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/img",
@ -41,20 +45,38 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
overwrite=True,
)
zarr_data.create_dataset(
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
"data/n_contacts",
shape=(num_frames, 2),
chunks=(num_frames, 2),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
"data/state",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
"data/keypoint",
shape=(num_frames, 9, 2),
chunks=(num_frames, 9, 2),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
"meta/episode_ends",
shape=(num_episodes,),
chunks=(num_episodes,),
dtype=np.int32,
overwrite=True,
)
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/img"][:] = np.random.randint(
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
)
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
@ -93,7 +115,11 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
"data/robot0_eef_pos",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_rot_axis_angle",
@ -110,10 +136,16 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
"meta/episode_ends",
shape=(num_episodes,),
chunks=(num_episodes,),
dtype=np.int32,
overwrite=True,
)
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/camera0_rgb"][:] = np.random.randint(
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
)
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
@ -129,7 +161,9 @@ def _mock_download_raw_xarm(raw_dir, num_frames=4):
dataset_dict = {
"observations": {
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
"rgb": np.random.randint(
0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8
),
"state": np.random.randn(num_frames, 4),
},
"actions": np.random.randn(num_frames, 3),
@ -151,13 +185,24 @@ def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
raw_dir.mkdir(parents=True, exist_ok=True)
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
with h5py.File(str(path_h5), "w") as f:
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset(
"action", data=np.random.randn(num_frames // num_episodes, 14)
)
f.create_dataset(
"observations/qpos",
data=np.random.randn(num_frames // num_episodes, 14),
)
f.create_dataset(
"observations/qvel",
data=np.random.randn(num_frames // num_episodes, 14),
)
f.create_dataset(
"observations/images/top",
data=np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
0,
255,
size=(num_frames // num_episodes, 480, 640, 3),
dtype=np.uint8,
),
)
@ -191,7 +236,12 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
action = np.random.randn(21).tolist()
state = np.random.randn(21).tolist()
ep_idx = episode_indices_mapping[i]
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
frame = [
{
"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4",
"timestamp": frame_idx / fps,
}
]
timestamps.append(t_utc)
actions.append(action)
states.append(state)
@ -204,7 +254,9 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
# write fake mp4 file for each episode
for ep_idx in range(num_episodes):
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
imgs_array = np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
)
tmp_imgs_dir = raw_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
@ -265,7 +317,9 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
def test_push_dataset_to_hub_format(
required_packages, tmpdir, raw_format, repo_id, make_test_data
):
num_episodes = 3
tmpdir = Path(tmpdir)
@ -317,7 +371,10 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
)
for k in ["from", "to"]:
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
assert torch.equal(
test_dataset.episode_data_index[k],
lerobot_dataset.episode_data_index[k][:1],
)
@pytest.mark.skip("push_dataset_to_hub is deprecated")
@ -359,8 +416,12 @@ def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, re
assert item1.keys() == item2.keys(), "Keys mismatch"
for key in item1:
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
if isinstance(item1[key], torch.Tensor) and isinstance(
item2[key], torch.Tensor
):
assert torch.equal(
item1[key], item2[key]
), f"Mismatch found in key: {key}"
else:
assert item1[key] == item2[key], f"Mismatch found in key: {key}"

View File

@ -95,7 +95,9 @@ def test_robot(tmpdir, request, robot_type, mock):
assert "observation.state" in observation
assert isinstance(observation["observation.state"], torch.Tensor)
assert observation["observation.state"].ndim == 1
dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
dim_state = sum(
len(robot.follower_arms[name].motors) for name in robot.follower_arms
)
assert observation["observation.state"].shape[0] == dim_state
# Cameras
for name in robot.cameras:
@ -106,7 +108,9 @@ def test_robot(tmpdir, request, robot_type, mock):
assert "action" in action
assert isinstance(action["action"], torch.Tensor)
assert action["action"].ndim == 1
dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
dim_action = sum(
len(robot.follower_arms[name].motors) for name in robot.follower_arms
)
assert action["action"].shape[0] == dim_action
# TODO(rcadene): test if observation and action data are returned as expected

View File

@ -15,7 +15,9 @@
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
hf_transform_to_torch,

View File

@ -9,7 +9,9 @@ from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
create_balanced_sampler,
@ -34,7 +36,9 @@ class MockDataset(Dataset):
def make_dummy_model():
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
num_classes=2,
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
num_cameras=1,
)
model = Classifier(config=model_config)
return model
@ -65,7 +69,9 @@ def test_create_balanced_sampler():
labels = [item["label"] for item in data]
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
class_weights = 1.0 / class_counts
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
expected_weights = torch.tensor(
[class_weights[label] for label in labels], dtype=torch.float32
)
# Test that the weights are correct
assert torch.allclose(weights, expected_weights)
@ -149,7 +155,9 @@ def test_validate():
def test_train_epoch_multiple_cameras():
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
num_classes=2,
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
num_cameras=2,
)
model = Classifier(config=model_config)
@ -216,10 +224,16 @@ def test_resume_function(
):
# Initialize Hydra
test_file_dir = os.path.dirname(os.path.abspath(__file__))
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
config_dir = os.path.abspath(
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
)
assert os.path.exists(
config_dir
), f"Config directory does not exist at {config_dir}"
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
with initialize_config_dir(
config_dir=config_dir, job_name="test_app", version_base="1.2"
):
cfg = compose(
config_name="hilserl_classifier",
overrides=[
@ -244,7 +258,9 @@ def test_resume_function(
mock_init_hydra_config.return_value = cfg
# Mock dataset
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
dataset = MockDataset(
[{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]
)
mock_dataset.return_value = dataset
# Mock checkpoint handling

View File

@ -1,7 +1,9 @@
import torch
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)

View File

@ -47,7 +47,9 @@ for motor_type in available_motors:
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
DYNAMIXEL_PORT = os.environ.get(
"LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081"
)
DYNAMIXEL_MOTORS = {
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
@ -57,7 +59,9 @@ DYNAMIXEL_MOTORS = {
"gripper": [6, "xl330-m288"],
}
FEETECH_PORT = os.environ.get("LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971")
FEETECH_PORT = os.environ.get(
"LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971"
)
FEETECH_MOTORS = {
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
@ -156,9 +160,13 @@ def require_package_arg(func):
if "required_packages" in arg_names:
# Get the index of 'required_packages' and retrieve the value from args
index = arg_names.index("required_packages")
required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
required_packages = (
args[index] if len(args) > index else kwargs.get("required_packages")
)
else:
raise ValueError("Function does not have 'required_packages' as an argument.")
raise ValueError(
"Function does not have 'required_packages' as an argument."
)
if required_packages is None:
return func(*args, **kwargs)
@ -215,11 +223,17 @@ def require_robot(func):
mock = kwargs.get("mock")
if robot_type is None:
raise ValueError("The 'robot_type' must be an argument of the test function.")
raise ValueError(
"The 'robot_type' must be an argument of the test function."
)
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
raise ValueError(
"The 'request' fixture must be an argument of the test function."
)
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
raise ValueError(
"The 'mock' variable must be an argument of the test function."
)
# Run test with a real robot. Skip test if robot connection fails.
if not mock and not request.getfixturevalue("is_robot_available"):
@ -239,11 +253,17 @@ def require_camera(func):
mock = kwargs.get("mock")
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
raise ValueError(
"The 'request' fixture must be an argument of the test function."
)
if camera_type is None:
raise ValueError("The 'camera_type' must be an argument of the test function.")
raise ValueError(
"The 'camera_type' must be an argument of the test function."
)
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
raise ValueError(
"The 'mock' variable must be an argument of the test function."
)
if not mock and not request.getfixturevalue("is_camera_available"):
pytest.skip(f"A {camera_type} camera is not available.")
@ -262,11 +282,17 @@ def require_motor(func):
mock = kwargs.get("mock")
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
raise ValueError(
"The 'request' fixture must be an argument of the test function."
)
if motor_type is None:
raise ValueError("The 'motor_type' must be an argument of the test function.")
raise ValueError(
"The 'motor_type' must be an argument of the test function."
)
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
raise ValueError(
"The 'mock' variable must be an argument of the test function."
)
if not mock and not request.getfixturevalue("is_motor_available"):
pytest.skip(f"A {motor_type} motor is not available.")
@ -285,7 +311,14 @@ def mock_calibration_dir(calibration_dir):
"start_pos": [1442, 843, 2166, 2849, 1988, 1835],
"end_pos": [2440, 1869, -1106, -1848, -926, 3235],
"calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"motor_names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
}
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
with open(calibration_dir / "main_follower.json", "w") as f: