[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-04 13:38:47 +00:00 committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions

View File

@ -32,7 +32,11 @@ import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
structural_similarity,
)
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int:
return total_size
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
def load_original_frames(
imgs_dir: Path, timestamps: list[float], fps: int
) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
imgs_dir: Path,
save_dir: Path,
frames: torch.Tensor,
timestamps: list[float],
fps: int,
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
@ -104,7 +114,10 @@ def save_decoded_frames(
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
shutil.copyfile(
imgs_dir / f"frame_{idx:06d}.png",
save_dir / f"frame_{idx:06d}_original.png",
)
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
img_keys = [
key for key in hf_dataset.features if key.startswith("observation.image")
]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
tqdm(
imgs_dataset,
desc=f"saving {dataset.repo_id} first episode images",
leave=False,
)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
break
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
def sample_timestamps(
timestamps_mode: str, ep_num_images: int, fps: int
) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
@ -154,7 +175,9 @@ def decode_video_frames(
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
else:
raise NotImplementedError(backend)
@ -181,7 +204,9 @@ def benchmark_decoding(
}
with time_benchmark:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
frames = decode_video_frames(
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
)
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
with time_benchmark:
@ -190,12 +215,18 @@ def benchmark_decoding(
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["mse_values"].append(
mean_squared_error(original_frames_np[i], frames_np[i])
)
result["psnr_values"].append(
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
peak_signal_noise_ratio(
original_frames_np[i], frames_np[i], data_range=1.0
)
)
result["ssim_values"].append(
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
structural_similarity(
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
)
)
if save_frames and sample == 0:
@ -215,7 +246,9 @@ def benchmark_decoding(
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
for future in tqdm(
as_completed(futures), total=num_samples, desc="samples", leave=False
):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
@ -275,9 +308,13 @@ def benchmark_encoding_decoding(
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
decoding_cfg["timestamps_modes"],
desc="decodings (timestamps_modes)",
leave=False,
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
for backend in tqdm(
decoding_cfg["backends"], desc="decodings (backends)", leave=False
):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
@ -355,14 +392,23 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for key, values in tqdm(
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
args_path = Path(
"_".join(str(value) for value in encoding_cfg.values())
)
video_path = (
output_dir
/ "videos"
/ args_path
/ f"{repo_id.replace('/', '_')}.mp4"
)
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
@ -388,7 +434,9 @@ def main(
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_path = (
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
)
concatenated_df.to_csv(concatenated_path, header=True, index=False)

View File

@ -32,7 +32,10 @@ import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
@ -40,7 +43,10 @@ pprint(lerobot.available_datasets)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
repo_ids = [
info.id
for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
]
pprint(repo_ids)
# Or simply explore them in your web browser directly at:
@ -55,7 +61,9 @@ ds_meta = LeRobotDatasetMetadata(repo_id)
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(
f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
)
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")

View File

@ -48,10 +48,14 @@ transforms = v2.Compose(
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
transformed_dataset = LeRobotDataset(
dataset_repo_id, episodes=[0], image_transforms=transforms
)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
transformed_frame = transformed_dataset[first_idx][
transformed_dataset.meta.camera_keys[0]
]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")

View File

@ -26,7 +26,10 @@ import math
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

View File

@ -0,0 +1,228 @@
import shutil
from pathlib import Path
import numpy as np
import torch
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
PUSHT_FEATURES = {
"observation.state": {
"dtype": "float32",
"shape": (2,),
"names": {
"axes": ["x", "y"],
},
},
"action": {
"dtype": "float32",
"shape": (2,),
"names": {
"axes": ["x", "y"],
},
},
"next.reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"next.success": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"observation.environment_state": {
"dtype": "float32",
"shape": (16,),
"names": [
"keypoints",
],
},
"observation.image": {
"dtype": None,
"shape": (3, 96, 96),
"names": [
"channel",
"height",
"width",
],
},
}
def build_features(mode: str) -> dict:
features = PUSHT_FEATURES
if mode == "keypoints":
features.pop("observation.image")
else:
features.pop("observation.environment_state")
features["observation.image"]["dtype"] = mode
return features
def load_raw_dataset(zarr_path: Path):
try:
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
return zarr_data
def calculate_coverage(zarr_data):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
block_pos = zarr_data["state"][:, 2:4]
block_angle = zarr_data["state"][:, 4]
num_frames = len(block_pos)
coverage = np.zeros((num_frames,))
# 8 keypoints with 2 coords each
keypoints = np.zeros((num_frames, 16))
# Set x, y, theta (in radians)
goal_pos_angle = np.array([256, 256, np.pi / 4])
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage[i] = intersection_area / goal_area
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
return coverage, keypoints
def calculate_success(coverage: float, success_threshold: float):
return coverage > success_threshold
def calculate_reward(coverage: float, success_threshold: float):
return np.clip(coverage / success_threshold, 0, 1)
def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
if mode not in ["video", "image", "keypoints"]:
raise ValueError(mode)
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if not raw_dir.exists():
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
env_state = zarr_data["state"][:]
agent_pos = env_state[:, :2]
action = zarr_data["action"][:]
image = zarr_data["img"] # (b, h, w, c)
episode_data_index = {
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
"to": zarr_data.meta["episode_ends"],
}
# Calculate success and reward based on the overlapping area
# of the T-object and the T-area.
coverage, keypoints = calculate_coverage(zarr_data)
success = calculate_success(coverage, success_threshold=0.95)
reward = calculate_reward(coverage, success_threshold=0.95)
features = build_features(mode)
dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=10,
robot_type="2d pointer",
features=features,
image_writer_threads=4,
)
episodes = range(len(episode_data_index["from"]))
for ep_idx in episodes:
from_idx = episode_data_index["from"][ep_idx]
to_idx = episode_data_index["to"][ep_idx]
num_frames = to_idx - from_idx
for frame_idx in range(num_frames):
i = from_idx + frame_idx
frame = {
"action": torch.from_numpy(action[i]),
# Shift reward and success by +1 until the last item of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
}
frame["observation.state"] = torch.from_numpy(agent_pos[i])
if mode == "keypoints":
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
else:
frame["observation.image"] = torch.from_numpy(image[i])
dataset.add_frame(frame)
dataset.save_episode(task=PUSHT_TASK)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
repo_id = "lerobot/pusht"
modes = ["video", "image", "keypoints"]
# Uncomment if you want to try with a specific mode
# modes = ["video"]
# modes = ["image"]
# modes = ["keypoints"]
raw_dir = Path("data/lerobot-raw/pusht_raw")
for mode in modes:
if mode in ["image", "keypoints"]:
repo_id += f"_{mode}"
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
main(raw_dir, repo_id=repo_id, mode=mode)
# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# breakpoint()

View File

@ -164,7 +164,11 @@ available_real_world_datasets = [
]
available_datasets = sorted(
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
set(
itertools.chain(
*available_datasets_per_env.values(), available_real_world_datasets
)
)
)
# lists all available policies from `lerobot/common/policies`
@ -205,9 +209,13 @@ available_policies_per_env = {
"aloha_real": ["act_aloha_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_task_pairs = [
(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
]
env_dataset_pairs = [
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
(env, dataset)
for env, datasets in available_datasets_per_env.items()
for dataset in datasets
]
env_dataset_policy_triplets = [
(env, dataset, policy)

View File

@ -127,7 +127,9 @@ class AsyncImageWriter:
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError("Number of threads and processes must be greater than zero.")
raise ValueError(
"Number of threads and processes must be greater than zero."
)
if self.num_processes == 0:
# Use threading
@ -141,12 +143,16 @@ class AsyncImageWriter:
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p = multiprocessing.Process(
target=worker_process, args=(self.queue, self.num_threads)
)
p.daemon = True
p.start()
self.processes.append(p)
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
def save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()

View File

@ -139,7 +139,9 @@ class LeRobotDatasetMetadata:
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
fpath = self.video_path.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
@ -183,7 +185,11 @@ class LeRobotDatasetMetadata:
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
return [
key
for key, ft in self.features.items()
if ft["dtype"] in ["video", "image"]
]
@property
def names(self) -> dict[str, list | dict]:
@ -285,7 +291,9 @@ class LeRobotDatasetMetadata:
"""
for key in self.video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
video_path = self.root / self.get_video_file_path(
ep_index=0, vid_key=key
)
self.info["features"][key]["info"] = get_video_info(video_path)
def __repr__(self):
@ -619,7 +627,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
path = str(self.root / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else:
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
files = [
str(self.root / self.meta.get_data_file_path(ep_idx))
for ep_idx in self.episodes
]
hf_dataset = load_dataset("parquet", data_files=files, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
@ -643,12 +654,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
return (
len(self.hf_dataset)
if self.hf_dataset is not None
else self.meta.total_frames
)
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
return (
len(self.episodes)
if self.episodes is not None
else self.meta.total_episodes
)
@property
def features(self) -> dict[str, dict]:
@ -662,16 +681,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
def _get_query_indices(
self, idx: int, ep_idx: int
) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
key: [
max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
for delta in delta_idx
]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
[
(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
for delta in delta_idx
]
)
for key, delta_idx in self.delta_indices.items()
}
@ -771,13 +798,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
def _get_image_file_path(
self, episode_index: int, image_key: str, frame_index: int
) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self.root / fpath
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
@ -803,7 +834,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
timestamp = (
frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
)
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
@ -821,7 +854,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
episode_index=self.episode_buffer["episode_index"],
image_key=key,
frame_index=frame_index,
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
@ -1132,7 +1167,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
features.update(
{
k: v
for k, v in dataset.hf_features.items()
if k not in self.disabled_features
}
)
return features
@property
@ -1193,7 +1234,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
raise AssertionError(
"We expect the loop to break out as long as the index is within bounds."
)
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:

View File

@ -131,7 +131,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
else:
self._delta_timestamps = None
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
def _make_data_spec(
self, data_spec: dict[str, Any], buffer_capacity: int
) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
@ -154,14 +156,32 @@ class OnlineBuffer(torch.utils.data.Dataset):
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
OnlineBuffer.OCCUPANCY_MASK_KEY: {
"dtype": np.dtype("?"),
"shape": (buffer_capacity,),
},
OnlineBuffer.INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.FRAME_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.EPISODE_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.TIMESTAMP_KEY: {
"dtype": np.dtype("float64"),
"shape": (buffer_capacity,),
},
}
for k, v in data_spec.items():
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
complete_data_spec[k] = {
"dtype": v["dtype"],
"shape": (buffer_capacity, *v["shape"]),
}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):
@ -188,7 +208,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
next_index - 1
]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@ -223,7 +245,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
np.unique(
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
]
)
)
@property
@ -261,7 +287,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
episode_data_indices
]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@ -278,7 +306,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
(query_ts[is_pad] < episode_timestamps[0])
| (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
@ -293,7 +322,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
return torch.from_numpy(
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
)
def compute_sampler_weights(
@ -324,13 +355,19 @@ def compute_sampler_weights(
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if len(offline_dataset) == 0 and (
online_dataset is None or len(online_dataset) == 0
):
raise ValueError(
"At least one of `offline_dataset` or `online_dataset` should be contain data."
)
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
offline_sampling_ratio = (
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
)
weights = []

View File

@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts):
return data_dict
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
def save_images_concurrently(
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
[
executor.submit(save_image, imgs_array[i], i, out_dir)
for i in range(num_images)
]
def get_default_encoding() -> dict:
@ -64,7 +69,8 @@ def get_default_encoding() -> dict:
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
if v.default is not inspect.Parameter.empty
and k in ["vcodec", "pix_fmt", "g", "crf"]
}
@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None:
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

View File

@ -43,7 +43,10 @@ class EpisodeAwareSampler:
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
range(
start_index.item() + drop_n_first_frames,
end_index.item() - drop_n_last_frames,
)
)
self.indices = indices

View File

@ -58,7 +58,9 @@ class RandomSubsetApply(Transform):
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
raise ValueError(
f"n_subset should be in the interval [1, {len(transforms)}]"
)
self.transforms = transforms
total = sum(p)
@ -119,16 +121,22 @@ class SharpnessJitter(Transform):
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
raise ValueError(
"If sharpness is a single number, it must be non negative."
)
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
raise TypeError(
f"{sharpness=} should be a single number or a sequence with length 2."
)
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
raise ValueError(
f"sharpnesss values should be between (0., inf), but got {sharpness}."
)
return float(sharpness[0]), float(sharpness[1])

View File

@ -52,9 +52,15 @@ STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DEFAULT_VIDEO_PATH = (
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
)
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
)
DEFAULT_IMAGE_PATH = (
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
)
DATASET_CARD_TEMPLATE = """
---
@ -540,7 +546,10 @@ def check_timestamps_sync(
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
delta_timestamps: dict[str, list[float]],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
@ -548,10 +557,14 @@ def check_delta_timestamps(
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
within_tolerance = [
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
ts
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
if not is_within
]
if len(outside_tolerance) > 0:
@ -569,7 +582,9 @@ def check_delta_timestamps(
return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
def get_delta_indices(
delta_timestamps: dict[str, list[float]], fps: int
) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
@ -634,7 +649,9 @@ def create_lerobot_dataset_card(
],
)
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
card_template = (
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
).read_text()
return DatasetCard.from_template(
card_data=card_data,

View File

@ -118,7 +118,10 @@ DATASETS = {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"aloha_static_candy": {
"single_task": "Pick up the candy and unwrap it.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
@ -167,13 +170,22 @@ DATASETS = {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_static_ziploc_slide": {
"single_task": "Slide open the ziploc bag.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_human": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
@ -194,10 +206,19 @@ DATASETS = {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"pusht_image": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {
"single_task": "Put the object into the bin.",
**UNITREEH_INFO,
},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
@ -207,13 +228,31 @@ DATASETS = {
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",

View File

@ -218,7 +218,9 @@ def get_features_from_hf_dataset(
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
robot_config["names"][key]
if robot_config
else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
@ -242,11 +244,15 @@ def get_features_from_hf_dataset(
return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
def add_task_index_by_episodes(
dataset: Dataset, tasks_by_episodes: dict
) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
episodes_to_task_index = {
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
@ -263,10 +269,19 @@ def add_task_index_from_tasks_col(
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
df[tasks_col] = (
df[tasks_col]
.str.removeprefix(prefix_to_clean)
.str.removesuffix(suffix_to_clean)
)
# Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks_by_episode = (
df.groupby("episode_index")[tasks_col]
.unique()
.apply(lambda x: x.tolist())
.to_dict()
)
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
@ -291,7 +306,9 @@ def split_parquet_by_episodes(
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk
)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@ -323,7 +340,9 @@ def move_videos(
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
video_files = [
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
@ -354,7 +373,9 @@ def move_videos(
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
video_file = V1_VIDEO_FILE.format(
video_key=vid_key, episode_index=ep_idx
)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
@ -371,7 +392,9 @@ def move_videos(
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
def fix_lfs_video_files_tracking(
work_dir: Path, lfs_untracked_videos: list[str]
) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
@ -379,7 +402,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
subprocess.run(
["git", "rm", "--cached", *files],
cwd=work_dir,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
@ -390,10 +418,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
def fix_gitattributes(
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
@ -402,7 +434,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
[
"git",
"clone",
"--branch",
branch,
"--single-branch",
"--depth",
"1",
repo_url,
str(work_dir),
],
check=True,
env=env,
)
@ -410,13 +452,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
["git", "lfs", "ls-files", "-n"],
cwd=work_dir,
capture_output=True,
text=True,
check=True,
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
def get_videos_info(
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
@ -424,7 +472,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
repo_id=repo_id,
repo_type="dataset",
local_dir=local_dir,
revision=branch,
allow_patterns=video_files,
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
@ -451,7 +503,11 @@ def convert_dataset(
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
repo_id=repo_id,
repo_type="dataset",
revision=v1,
local_dir=v1x_dir,
ignore_patterns="videos*/",
)
branch = "main"
if test_branch:
@ -483,19 +539,31 @@ def convert_dataset(
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
dataset, tasks_col
)
else:
raise ValueError
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
assert set(tasks) == {
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
}
tasks = [
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
@ -509,14 +577,25 @@ def convert_dataset(
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
repo_id=GITATTRIBUTES_REF,
repo_type="dataset",
local_dir=local_dir,
filename=".gitattributes",
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
repo_id,
video_keys,
total_episodes,
total_chunks,
Path(tmp_video_dir),
clean_gitattr,
branch,
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
videos_info = get_videos_info(
repo_id, v1x_dir, video_keys=video_keys, branch=branch
)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
@ -524,15 +603,22 @@ def convert_dataset(
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
assert math.isclose(
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
)
if "encoding" in metadata_v1:
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
assert (
videos_info[key]["video.pix_fmt"]
== metadata_v1["encoding"]["pix_fmt"]
)
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
episode_lengths = split_parquet_by_episodes(
dataset, total_episodes, total_chunks, v20_dir
)
if robot_config is not None:
robot_type = robot_config.type
@ -543,7 +629,11 @@ def convert_dataset(
# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
{
"episode_index": ep_idx,
"tasks": tasks_by_episodes[ep_idx],
"length": episode_lengths[ep_idx],
}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
@ -566,16 +656,27 @@ def convert_dataset(
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
card = create_lerobot_dataset_card(
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id,
path_in_repo="meta_data",
repo_type="dataset",
revision=branch,
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
)
hub_api.upload_folder(
repo_id=repo_id,

View File

@ -344,7 +344,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@ -358,7 +360,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.bit_rate": int(audio_stream_info["bit_rate"])
if audio_stream_info.get("bit_rate")
else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
@ -380,7 +384,9 @@ def get_video_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")

View File

@ -70,7 +70,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
def make_maniskill_env(
cfg: DictConfig, n_envs: int | None = None
) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
@ -87,7 +89,9 @@ def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
return env
@ -114,7 +118,11 @@ class PixelWrapper(gym.Wrapper):
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
return {
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
self.env.device
)
}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
@ -148,7 +156,9 @@ class ConvertToLeRobotEnv(gym.Wrapper):
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
ret = dict()
ret["state"] = observation
ret["pixels"] = images

View File

@ -84,7 +84,9 @@ class Logger:
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
def __init__(
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
@ -104,7 +106,9 @@ class Logger:
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
logging.info(
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
)
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
@ -130,7 +134,9 @@ class Logger:
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb
@classmethod
@ -151,7 +157,9 @@ class Logger:
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
def save_model(
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
@ -221,22 +229,30 @@ class Logger:
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
checkpoint_dir / self.pretrained_model_dir_name,
policy,
wandb_artifact_name=wandb_artifact_name,
)
self.save_training_state(
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
def load_last_training_state(
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
training_state = torch.load(
self.last_checkpoint_dir / self.training_state_file_name
)
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
assert set(training_state["optimizer"].keys()) == set(
optimizer.keys()
), "Optimizer dictionaries do not have the same keys during resume!"
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:
@ -248,10 +264,18 @@ class Logger:
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
set_global_random_state(
{k: training_state[k] for k in get_global_random_state()}
)
return training_state["step"]
def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
def log_dict(
self,
d,
step: int | None = None,
mode="train",
custom_step_key: str | None = None,
):
"""Log a dictionary of metrics to WandB."""
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
@ -280,12 +304,20 @@ class Logger:
continue
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
self._wandb.log(
{
f"{mode}/{k}": v,
f"{mode}/{custom_step_key}": value_custom_step,
}
)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)

View File

@ -74,7 +74,9 @@ class ACTPolicy(PreTrainedPolicy):
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.temporal_ensembler = ACTTemporalEnsembler(
config.temporal_ensemble_coeff, config.chunk_size
)
self.reset()
@ -153,7 +155,8 @@ class ACTPolicy(PreTrainedPolicy):
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}
@ -163,7 +166,12 @@ class ACTPolicy(PreTrainedPolicy):
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
(
-0.5
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
)
.sum(-1)
.mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight
@ -217,7 +225,9 @@ class ACTTemporalEnsembler:
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights = torch.exp(
-temporal_ensemble_coeff * torch.arange(chunk_size)
)
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
@ -233,7 +243,9 @@ class ACTTemporalEnsembler:
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
device=actions.device
)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
@ -241,19 +253,34 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
(self.chunk_size, 1),
dtype=torch.long,
device=self.ensembled_actions.device,
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
self.ensembled_actions *= self.ensemble_weights_cumsum[
self.ensembled_actions_count - 1
]
self.ensembled_actions += (
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
)
self.ensembled_actions /= self.ensemble_weights_cumsum[
self.ensembled_actions_count
]
self.ensembled_actions_count = torch.clamp(
self.ensembled_actions_count + 1, max=self.chunk_size
)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions = torch.cat(
[self.ensembled_actions, actions[:, -1:]], dim=1
)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
[
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
@ -319,7 +346,9 @@ class ACT(nn.Module):
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
self.vae_encoder_latent_output_proj = nn.Linear(
config.dim_model, config.latent_dim * 2
)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
@ -327,20 +356,28 @@ class ACT(nn.Module):
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
create_sinusoidal_pos_embedding(
num_input_token_encoder, config.dim_model
).unsqueeze(0),
)
# Backbone for image feature extraction.
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer4": "feature_map"}
)
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
@ -386,7 +423,9 @@ class ACT(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
def forward(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
@ -424,7 +463,9 @@ class ACT(nn.Module):
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
@ -465,20 +506,24 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
latent_sample = torch.zeros(
[batch_size, self.config.latent_dim], dtype=torch.float32
).to(batch["observation.state"].device)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
encoder_in_pos_embed = list(
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
)
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
self.encoder_env_state_input_proj(
batch["observation.environment_state"]
)
)
# Camera observation features and positional embeddings.
@ -535,12 +580,21 @@ class ACTEncoder(nn.Module):
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__()
self.is_vae_encoder = is_vae_encoder
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
num_layers = (
config.n_vae_encoder_layers
if self.is_vae_encoder
else config.n_encoder_layers
)
self.layers = nn.ModuleList(
[ACTEncoderLayer(config) for _ in range(num_layers)]
)
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@ -551,7 +605,9 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@ -566,7 +622,9 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
def forward(
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
@ -591,7 +649,9 @@ class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.layers = nn.ModuleList(
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
)
self.norm = nn.LayerNorm(config.dim_model)
def forward(
@ -603,7 +663,10 @@ class ACTDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
)
if self.norm is not None:
x = self.norm(x)
@ -613,8 +676,12 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.multihead_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@ -655,7 +722,9 @@ class ACTDecoderLayer(nn.Module):
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x)[
0
] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@ -692,9 +761,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
return [
position / np.power(10000, 2 * (hid_j // 2) / dimension)
for hid_j in range(dimension)
]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
@ -739,7 +813,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
2
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
/ self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
@ -747,9 +823,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
pos_embed_x = torch.stack(
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed_y = torch.stack(
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
0, 3, 1, 2
) # (1, C, H, W)
return pos_embed

View File

@ -132,7 +132,11 @@ class DiffusionPolicy(PreTrainedPolicy):
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
@ -189,7 +193,9 @@ class DiffusionModel(nn.Module):
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.unet = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
@ -209,7 +215,10 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
@ -232,7 +241,9 @@ class DiffusionModel(nn.Module):
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
sample = self.noise_scheduler.step(
model_output, t, sample, generator=generator
).prev_sample
return sample
@ -244,27 +255,39 @@ class DiffusionModel(nn.Module):
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
images_per_camera = einops.rearrange(
batch["observation.images"], "b s n ... -> n (b s) ..."
)
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
for encoder, images in zip(
self.rgb_encoder, images_per_camera, strict=True
)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features_list,
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
einops.rearrange(
batch["observation.images"], "b s n ... -> (b s n) ..."
)
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features,
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
global_cond_feats.append(img_features)
@ -350,7 +373,9 @@ class DiffusionModel(nn.Module):
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
raise ValueError(
f"Unsupported prediction type {self.config.prediction_type}"
)
loss = F.mse_loss(pred, target, reduction="none")
@ -410,7 +435,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@ -452,7 +479,9 @@ class DiffusionRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@ -473,7 +502,9 @@ class DiffusionRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@ -515,7 +546,9 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@ -528,7 +561,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@ -543,7 +580,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module
@ -571,7 +610,9 @@ class DiffusionConv1dBlock(nn.Module):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.Conv1d(
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
@ -594,9 +635,13 @@ class DiffusionConditionalUnet1d(nn.Module):
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Linear(
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
),
nn.Mish(),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
nn.Linear(
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
),
)
# The FiLM conditioning dimension.
@ -621,10 +666,16 @@ class DiffusionConditionalUnet1d(nn.Module):
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(
dim_in, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
@ -633,10 +684,14 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
]
)
@ -649,10 +704,16 @@ class DiffusionConditionalUnet1d(nn.Module):
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(
dim_in * 2, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
@ -726,17 +787,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
self.conv1 = DiffusionConv1dBlock(
in_channels, out_channels, kernel_size, n_groups=n_groups
)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
self.conv2 = DiffusionConv1dBlock(
out_channels, out_channels, kernel_size, n_groups=n_groups
)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
nn.Conv1d(in_channels, out_channels, 1)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:

View File

@ -7,7 +7,9 @@ from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@ -15,7 +17,10 @@ class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
@ -43,12 +48,14 @@ class Classifier(
name = "classifier"
def __init__(self, config: ClassifierConfig):
from transformers import AutoImageProcessor, AutoModel
from transformers import AutoModel
super().__init__()
self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(
self.config.model_name, trust_remote_code=True
)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
@ -74,7 +81,9 @@ class Classifier(
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
self.feature_dim = self.encoder.config.hidden_sizes[
-1
] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
@ -94,14 +103,19 @@ class Classifier(
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
raise ValueError(
"Unsupported transformer architecture since hidden_size is not found"
)
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
self.classifier_head = self.classifier_head.to(self.config.device)
@ -127,7 +141,10 @@ class Classifier(
return features
else: # Transformer models
outputs = self.encoder(processed)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
if (
hasattr(outputs, "pooler_output")
and outputs.pooler_output is not None
):
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
@ -143,7 +160,9 @@ class Classifier(
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
return ClassifierOutput(
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
)
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:

View File

@ -59,7 +59,9 @@ class SACPolicy(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, input_normalization_params
config.input_shapes,
config.input_normalization_modes,
input_normalization_params,
)
else:
self.normalize_inputs = nn.Identity()
@ -90,7 +92,8 @@ class SACPolicy(
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@ -104,7 +107,8 @@ class SACPolicy(
ensemble=Ensemble(
[
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim
+ config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@ -120,13 +124,17 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
network=MLP(
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
config.target_entropy = (
-np.prod(config.output_shapes["action"][0]) / 2
) # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@ -153,7 +161,11 @@ class SACPolicy(
return actions
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
@ -173,21 +185,37 @@ class SACPolicy(
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
@ -204,7 +232,12 @@ class SACPolicy(
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
q_preds = self.critic_forward(
observations,
actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
@ -219,20 +252,31 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
def compute_loss_temperature(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
temperature_loss = (
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
).mean()
return temperature_loss
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
q_preds = self.critic_forward(
observations,
actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
@ -259,7 +303,11 @@ class MLP(nn.Module):
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
# Rest of the layers
for i in range(1, len(hidden_dims)):
@ -270,7 +318,9 @@ class MLP(nn.Module):
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@ -381,7 +431,11 @@ class CriticEnsemble(nn.Module):
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
inputs = torch.cat([obs_enc, actions], dim=-1)
q_values = self.ensemble(inputs) # [num_critics, B, 1]
@ -445,7 +499,11 @@ class Policy(nn.Module):
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
# Get network outputs
outputs = self.network(obs_enc)
@ -454,11 +512,15 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
@ -471,7 +533,9 @@ class Policy(nn.Module):
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
log_probs -= torch.log(
(1 - actions.pow(2)) + 1e-6
) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
@ -518,12 +582,15 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.all_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
in_features=config.input_shapes["observation.state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
@ -544,7 +611,9 @@ class SACObservationEncoder(nn.Module):
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.aggregation_layer = nn.Linear(
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
@ -557,13 +626,19 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = torch.cat(
[obs_dict[key] for key in self.all_image_keys], dim=0
)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
embeddings_chunks = torch.chunk(
images_batched, dim=0, chunks=len(self.all_image_keys)
)
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
@ -631,7 +706,9 @@ class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_layers, self.image_enc_out_shape = (
self._load_pretrained_vision_encoder(config)
)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
@ -642,15 +719,21 @@ class PretrainedImageEncoder(nn.Module):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
self.image_enc_layers = AutoModel.from_pretrained(
config.vision_encoder_name, trust_remote_code=True
)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
-1
] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
raise ValueError(
"Unsupported vision encoder architecture, make sure you are using a CNN"
)
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
@ -673,7 +756,7 @@ def orthogonal_init():
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
super().__init__()
def forward(self, x):
return x
@ -701,7 +784,9 @@ class Ensemble(nn.Module):
return self.module(*args, **kwargs)
def forward(self, *args, **kwargs):
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
return torch.vmap(self._call, (0, None), randomness="different")(
self.params, *args, **kwargs
)
def __repr__(self):
return f"Vectorized {len(self)}x " + self._repr
@ -710,7 +795,9 @@ class Ensemble(nn.Module):
# TODO (azouitine): I think in our case this function is not usefull we should remove it
# after some investigation
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
@ -736,7 +823,9 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
converted_params[outer_key][key] = converted_params[outer_key][
key
].view(3, 1, 1)
return converted_params

View File

@ -183,7 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
raise ValueError(
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")

View File

@ -100,7 +100,9 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
"action": deque(
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
),
}
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
@ -189,7 +191,11 @@ class TDMPCPolicy(PreTrainedPolicy):
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
z = einops.repeat(
z,
"b d -> n b d",
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
@ -211,35 +217,47 @@ class TDMPCPolicy(PreTrainedPolicy):
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
gaussian_actions = torch.clamp(
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_idxs = torch.topk(
value, self.config.n_elites, dim=0
).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
elite_actions = actions.take_along_dim(
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score = torch.exp(
self.config.elite_weighting_temperature * (elite_value - max_value)
)
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_mean = torch.sum(
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
self.config.gaussian_mean_momentum * mean
+ (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
@ -248,7 +266,9 @@ class TDMPCPolicy(PreTrainedPolicy):
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
actions = elite_actions[
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
]
return actions
@ -271,7 +291,8 @@ class TDMPCPolicy(PreTrainedPolicy):
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
self.config.uncertainty_regularizer_coeff
* self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
@ -291,15 +312,22 @@ class TDMPCPolicy(PreTrainedPolicy):
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
* torch.min(
terminal_values[
torch.randint(0, self.config.q_ensemble_size, size=(2,))
],
dim=0,
)[0]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
G -= (
running_discount
* self.config.uncertainty_regularizer_coeff
* terminal_values.std(0)
)
return G
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@ -329,7 +357,10 @@ class TDMPCPolicy(PreTrainedPolicy):
# Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
partial(
random_shifts_aug,
max_random_shift_ratio=self.config.max_random_shift_ratio,
),
observations["observation.image"],
)
@ -347,14 +378,20 @@ class TDMPCPolicy(PreTrainedPolicy):
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds = torch.empty(
horizon + 1, batch_size, self.config.latent_dim, device=device
)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
z_preds[t], action[t]
)
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
q_preds_ensemble = self.model.Qs(
z_preds[:-1], action
) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
@ -368,10 +405,14 @@ class TDMPCPolicy(PreTrainedPolicy):
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
q_targets = reward + self.config.discount * self.model.V(
self.model.encode(next_observations)
)
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
v_targets = self.model_target.Qs(
z_preds[:-1].detach(), action, return_min=True
)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@ -414,7 +455,9 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
einops.repeat(
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
@ -452,12 +495,14 @@ class TDMPCPolicy(PreTrainedPolicy):
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
advantage = self.model_target.Qs(
z_preds[:-1], action, return_min=True
) - self.model.V(z_preds[:-1])
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
exp_advantage = torch.clamp(
torch.exp(advantage * self.config.advantage_scaling), max=100.0
)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@ -511,7 +556,9 @@ class TDMPCPolicy(PreTrainedPolicy):
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
update_ema_parameters(
self.model_target, self.model, self.config.target_model_momentum
)
class TDMPCTOLD(nn.Module):
@ -598,7 +645,9 @@ class TDMPCTOLD(nn.Module):
"Sanity check. The last linear layer needs 0 initialization on weights."
)
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
nn.init.zeros_(
m[-1].bias
) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
@ -702,11 +751,26 @@ class TDMPCObservationEncoder(nn.Module):
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
)
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
@ -796,12 +860,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
ema_module.named_parameters(recurse=False),
module.named_parameters(recurse=False),
strict=True,
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
if (
isinstance(module, nn.modules.batchnorm._BatchNorm)
or not p.requires_grad
):
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
@ -809,7 +878,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:

View File

@ -145,8 +145,14 @@ class VQBeTPolicy(PreTrainedPolicy):
)
if len(self._queues["action"]) == 0:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.vqbet(batch, rollout=True)[
:, : self.config.action_chunk_size
]
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
@ -168,7 +174,9 @@ class VQBeTPolicy(PreTrainedPolicy):
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
self.vqbet.action_head.discretize(
self.config.n_vqvae_training_steps, batch["action"]
)
)
return loss, {
"n_different_codes": n_different_codes,
@ -225,7 +233,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@ -339,7 +349,12 @@ class VQBeTModel(nn.Module):
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
torch.row_stack(
[
torch.arange(i, i + self.config.action_chunk_size)
for i in range(num_tokens)
]
),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
@ -354,7 +369,11 @@ class VQBeTModel(nn.Module):
)
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
img_features,
"(b s n) ... -> b s n ...",
b=batch_size,
s=n_obs_steps,
n=self.num_images,
)
# Arrange prior and current observation step tokens as shown in the class docstring.
@ -366,13 +385,19 @@ class VQBeTModel(nn.Module):
input_tokens.append(
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
input_tokens.append(
einops.repeat(
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
)
)
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token - 1
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
future_action_tokens = self.action_token.repeat(
batch_size, len_additional_action_token, 1
)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
@ -391,7 +416,11 @@ class VQBeTModel(nn.Module):
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0:
features = torch.cat(
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
[
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:],
],
dim=1,
)
else:
features = features[:, historical_act_pred_index]
@ -399,13 +428,15 @@ class VQBeTModel(nn.Module):
action_head_output = self.action_head(features)
# if rollout, VQ-BeT don't calculate loss
if rollout:
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
batch_size, self.config.action_chunk_size, -1
)
return action_head_output["predicted_action"][
:, n_obs_steps - 1, :
].reshape(batch_size, self.config.action_chunk_size, -1)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
output = batch["action"][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
loss = self.action_head.loss_fn(
action_head_output, output, reduction="mean"
)
return action_head_output, loss
@ -440,7 +471,9 @@ class VQBeTHead(nn.Module):
else:
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
hidden_channels=[
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
],
)
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
@ -467,7 +500,10 @@ class VQBeTHead(nn.Module):
loss, metric = self.vqvae_model.vqvae_forward(actions)
n_different_codes = sum(
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
[
len(torch.unique(metric[2][:, i]))
for i in range(self.vqvae_model.vqvae_num_layers)
]
)
n_different_combinations = len(torch.unique(metric[2], dim=0))
recon_l1_error = metric[0].detach().cpu().item()
@ -514,7 +550,13 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat(
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
(
x,
F.one_hot(
sampled_primary_centers,
num_classes=self.config.vqvae_n_embed,
),
),
axis=1,
)
)
@ -522,19 +564,29 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
)
sampled_secondary_centers = einops.rearrange(
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
torch.multinomial(
cbet_secondary_probs.view(-1, choices), num_samples=1
),
"(NT) 1 -> NT",
NT=NT,
)
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
sampled_centers = torch.stack(
(sampled_primary_centers, sampled_secondary_centers), axis=1
)
cbet_logits = torch.stack(
[cbet_primary_logits, cbet_secondary_logits], dim=1
)
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
cbet_logits,
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
)
cbet_probs = torch.softmax(
cbet_logits / self.config.bet_softmax_temperature, dim=-1
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
@ -554,9 +606,17 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
with torch.no_grad():
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
return_decoder_input = (
self.vqvae_model.get_embeddings_from_code(sampled_centers)
.clone()
.detach()
)
# pass the centroids through decoder to get actions.
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
@ -605,7 +665,9 @@ class VQBeTHead(nn.Module):
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
# Now we can compute the loss.
@ -628,8 +690,12 @@ class VQBeTHead(nn.Module):
+ cbet_loss2 * self.config.secondary_code_loss_weight
)
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
@ -643,7 +709,9 @@ class VQBeTHead(nn.Module):
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
.cpu()
.item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
@ -668,7 +736,9 @@ class VQBeTRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@ -689,7 +759,9 @@ class VQBeTRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@ -730,7 +802,9 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@ -743,7 +817,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@ -758,7 +836,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module

View File

@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module):
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = (
y.transpose(1, 2).contiguous().view(B, T, C)
) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
@ -189,12 +197,16 @@ class GPT(nn.Module):
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
}
)
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
self.lm_head = nn.Linear(
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
@ -208,11 +220,17 @@ class GPT(nn.Module):
)
# positional encodings that are added to the input embeddings
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)
# forward the GPT model itself
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, gpt_hidden_dim)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
@ -237,7 +255,9 @@ class GPT(nn.Module):
# but want to use a smaller block size for some smaller, simpler model
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:gpt_block_size]
)
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
@ -270,7 +290,9 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
assert (
len(inter_params) == 0
), "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert len(param_dict.keys() - union_params) == 0, (
@ -368,8 +390,12 @@ class ResidualVQ(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.num_quantizers = num_quantizers
@ -377,7 +403,10 @@ class ResidualVQ(nn.Module):
self.layers = nn.ModuleList(
[
VectorQuantize(
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
dim=codebook_dim,
codebook_dim=codebook_dim,
accept_image_fmap=accept_image_fmap,
**kwargs,
)
for _ in range(num_quantizers)
]
@ -448,7 +477,9 @@ class ResidualVQ(nn.Module):
return all_codes
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
def forward(
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
):
"""
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
@ -477,13 +508,17 @@ class ResidualVQ(nn.Module):
)
ce_losses = []
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
should_quantize_dropout = (
self.training and self.quantize_dropout and not return_loss
)
# sample a layer index at which to dropout further residual quantization
# also prepare null indices and loss
if should_quantize_dropout:
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
rand_quantize_dropout_index = randrange(
self.quantize_dropout_cutoff_index, num_quant
)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
@ -492,14 +527,23 @@ class ResidualVQ(nn.Module):
- 1
)
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
null_indices_shape = (
(x.shape[0], *x.shape[-2:])
if self.accept_image_fmap
else tuple(x.shape[:2])
)
null_indices = torch.full(
null_indices_shape, -1.0, device=device, dtype=torch.long
)
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
# go through the layers
for quantizer_index, layer in enumerate(self.layers):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
if (
should_quantize_dropout
and quantizer_index > rand_quantize_dropout_index
):
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
@ -539,7 +583,9 @@ class ResidualVQ(nn.Module):
# stack all losses and indices
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
all_losses, all_indices = map(
partial(torch.stack, dim=-1), (all_losses, all_indices)
)
ret = (quantized_out, all_indices, all_losses)
@ -599,8 +645,12 @@ class VectorQuantize(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.eps = eps
self.commitment_weight = commitment_weight
@ -614,10 +664,14 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
assert not (
ema_update and learnable_codebook
), "learnable codebook not compatible with EMA update"
assert 0 <= sync_update_v <= 1.0
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
assert not (
sync_update_v > 0.0 and not learnable_codebook
), "learnable codebook must be turned on"
self.sync_update_v = sync_update_v
@ -629,7 +683,9 @@ class VectorQuantize(nn.Module):
)
if sync_codebook is None:
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
codebook_kwargs = {
"dim": codebook_dim,
@ -794,11 +850,17 @@ class VectorQuantize(nn.Module):
# quantize again
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
quantize, embed_ind, distances = self._codebook(
x, **codebook_forward_kwargs
)
if self.training:
# determine code to use for commitment loss
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
maybe_detach = (
torch.detach
if not self.learnable_codebook or freeze_codebook
else identity
)
commit_quantize = maybe_detach(quantize)
@ -808,7 +870,9 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
quantize = quantize + self.sync_update_v * (
quantize - quantize.detach()
)
# function for calculating cross entropy loss to distance matrix
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@ -841,7 +905,9 @@ class VectorQuantize(nn.Module):
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
if self.accept_image_fmap:
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
embed_ind = rearrange(
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
)
if only_one:
embed_ind = rearrange(embed_ind, "b 1 -> b")
@ -895,8 +961,12 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2]
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
if (
self.orthogonal_reg_max_codes is not None
) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
codebook = codebook[:, rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@ -928,7 +998,9 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True
if mask is not None:
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
return quantize, embed_ind, loss
@ -1038,7 +1110,9 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num):
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
return torch.stack(
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
def pad_shape(shape, size, dim=0):
@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0:
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
samples_per_rank = sample_multinomial(
num, all_num_samples / all_num_samples.sum()
)
else:
samples_per_rank = torch.empty_like(all_num_samples)
@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module):
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = (
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
reset_cluster_size
if (reset_cluster_size is not None)
else threshold_ema_dead_code
)
assert callable(gumbel_sample)
@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module):
"kmeans init is not compatible with multiple codebooks in distributed environment for now"
)
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.sample_fn = (
sample_vectors_distributed
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_number)
batch_variance = variance_number / num_vectors
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
self.update_with_decay(
"batch_variance", batch_variance, self.affine_param_batch_decay
)
def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate(
@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
sampled = self.sample_fn(
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled
@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
)
x = x.float()
@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module):
if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
embed = (embed - self.codebook_mean) * (
batch_std / codebook_std
) + self.batch_mean
dist = -cdist(flatten, embed)
@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module):
if self.training and self.ema_update and not freeze_codebook:
if self.affine_param:
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
flatten = (flatten - self.batch_mean) * (
codebook_std / batch_std
) + self.codebook_mean
if mask is not None:
embed_onehot[~mask] = 0.0
@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module):
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
quantize, embed_ind = tuple(
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
)
dist = unpack_one(dist, ps, "h * d")

View File

@ -79,7 +79,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
img.save(str(path), quality=100)
logging.info(f"Saved image: {path}")
except Exception as e:
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
logging.error(
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
)
def save_images_from_cameras(
@ -157,7 +159,9 @@ def save_images_from_cameras(
if time.perf_counter() - start_time > record_time_s:
break
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
frame_index += 1
finally:
@ -275,7 +279,9 @@ class IntelRealSenseCamera:
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
)
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
name_to_serial_dict = {
cam["name"]: cam["serial_number"] for cam in camera_infos
}
cam_sn = name_to_serial_dict[name]
return cam_sn
@ -339,7 +345,9 @@ class IntelRealSenseCamera:
actual_height = color_profile.height()
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@ -359,7 +367,9 @@ class IntelRealSenseCamera:
self.is_connected = True
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
def read(
self, temporary_color: str | None = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
@ -386,11 +396,15 @@ class IntelRealSenseCamera:
color_frame = frame.get_color_frame()
if not color_frame:
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
raise OSError(
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
)
color_image = np.asanyarray(color_frame.get_data())
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
requested_color_mode = (
self.color_mode if temporary_color is None else temporary_color
)
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@ -418,7 +432,9 @@ class IntelRealSenseCamera:
if self.use_depth:
depth_frame = frame.get_depth_frame()
if not depth_frame:
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
raise OSError(
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
)
depth_map = np.asanyarray(depth_frame.get_data())
@ -460,7 +476,9 @@ class IntelRealSenseCamera:
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
if num_tries > self.fps and (
self.thread.ident is None or not self.thread.is_alive()
):
raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
)

View File

@ -45,10 +45,14 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
def find_cameras(
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
) -> list[dict]:
cameras = []
if platform.system() == "Linux":
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
print(
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
)
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
ports = _find_cameras(possible_ports, mock=mock)
for port in ports:
@ -180,7 +184,9 @@ def save_images_from_cameras(
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
if time.perf_counter() - start_time > record_time_s:
break
@ -237,7 +243,9 @@ class OpenCVCamera:
if platform.system() == "Linux":
if isinstance(self.camera_index, int):
self.port = Path(f"/dev/video{self.camera_index}")
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
elif isinstance(self.camera_index, str) and is_valid_unix_path(
self.camera_index
):
self.port = Path(self.camera_index)
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
@ -283,7 +291,9 @@ class OpenCVCamera:
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
raise RobotDeviceAlreadyConnectedError(
f"OpenCVCamera({self.camera_index}) is already connected."
)
if self.mock:
import tests.cameras.mock_cv2 as cv2
@ -344,7 +354,9 @@ class OpenCVCamera:
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
@ -386,7 +398,9 @@ class OpenCVCamera:
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
requested_color_mode = (
self.color_mode if temporary_color_mode is None else temporary_color_mode
)
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(

View File

@ -39,7 +39,9 @@ from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
def log_control_info(
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@ -106,7 +108,9 @@ def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and use_amp
else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
@ -162,7 +166,9 @@ def init_keyboard_listener(assign_rewards=False):
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
@ -256,7 +262,9 @@ def control_loop(
raise ValueError("You need to provide a task as argument in `single_task`.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
raise ValueError(
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
)
timestamp = 0
start_episode_t = time.perf_counter()
@ -291,7 +299,9 @@ def control_loop(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
if fps is not None:
@ -361,7 +371,11 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
@ -375,11 +389,14 @@ def sanity_check_dataset_robot_compatibility(
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
diff = DeepDiff(
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
)
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
"Dataset metadata compatibility check failed with mismatches:\n"
+ "\n".join(mismatches)
)

View File

@ -158,7 +158,9 @@ NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@ -384,7 +386,9 @@ class DynamixelMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError:
continue
@ -400,7 +404,9 @@ class DynamixelMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@ -421,7 +427,9 @@ class DynamixelMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@ -434,7 +442,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@ -509,7 +519,9 @@ class DynamixelMotorsBus:
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@ -551,15 +563,23 @@ class DynamixelMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-(resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
(resolution // 2) - values[i] - homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
@ -567,7 +587,9 @@ class DynamixelMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@ -583,19 +605,27 @@ class DynamixelMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@ -605,7 +635,9 @@ class DynamixelMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@ -644,7 +676,9 @@ class DynamixelMotorsBus:
values = np.round(values).astype(np.int32)
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
@ -746,7 +780,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@ -755,7 +791,9 @@ class DynamixelMotorsBus:
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
@ -784,7 +822,12 @@ class DynamixelMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@ -845,7 +888,9 @@ class DynamixelMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@ -137,7 +137,9 @@ NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@ -365,7 +367,9 @@ class FeetechMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError:
continue
@ -381,7 +385,9 @@ class FeetechMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@ -402,7 +408,9 @@ class FeetechMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@ -415,7 +423,9 @@ class FeetechMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@ -489,7 +499,9 @@ class FeetechMotorsBus:
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@ -528,18 +540,26 @@ class FeetechMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution
upp_factor = (
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@ -548,7 +568,9 @@ class FeetechMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@ -564,19 +586,27 @@ class FeetechMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@ -586,7 +616,9 @@ class FeetechMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@ -662,7 +694,9 @@ class FeetechMotorsBus:
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
@ -771,7 +805,9 @@ class FeetechMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@ -780,7 +816,9 @@ class FeetechMotorsBus:
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
@ -809,7 +847,12 @@ class FeetechMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@ -870,7 +913,9 @@ class FeetechMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@ -24,9 +24,7 @@ from lerobot.common.robot_devices.motors.dynamixel import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@ -37,7 +35,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode):
@ -78,12 +78,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@ -104,10 +108,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@ -116,11 +125,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
rotated_nearest_pos = compute_nearest_rounded_position(
rotated_drived_pos, arm.motor_models
)
homing_offset = rotated_target_pos - rotated_nearest_pos
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...")
print()

View File

@ -26,9 +26,7 @@ from lerobot.common.robot_devices.motors.feetech import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@ -39,7 +37,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode):
@ -140,7 +140,9 @@ def apply_offset(calib, offset):
return calib
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
if robot_type == "so100":
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
elif robot_type == "moss":
@ -149,18 +151,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm
raise ValueError(robot_type)
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration_so100(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "so100" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
raise NotImplementedError(
"Auto calibration only supports the follower of so100 arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@ -207,11 +218,16 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate(
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
arm,
"elbow_flex",
positive_first=False,
in_between_move_hook=in_between_move_hook,
)
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
)
time.sleep(1)
def in_between_move_hook():
@ -239,18 +255,30 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
}
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
arm.write(
"Goal_Position",
round(calib["shoulder_lift"]["zero_pos"] - 1600),
"shoulder_lift",
)
time.sleep(2)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
arm.write(
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
)
time.sleep(2)
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
arm.write(
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
)
time.sleep(2)
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
time.sleep(2)
print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate(
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
arm,
"wrist_roll",
invert_drive_mode=True,
positive_first=False,
while_move_hook=while_move_hook,
)
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
@ -260,7 +288,9 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
)
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
@ -289,18 +319,27 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
return calib_dict
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration_moss(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
raise NotImplementedError(
"Auto calibration only supports the follower of moss arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@ -384,8 +423,12 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
)
time.sleep(2)
calib_modes = []
@ -412,7 +455,9 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
return calib_dict
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_manual_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance before calibration, setting a same goal position
for each motor of two different robots will get two very different positions. But after calibration,
@ -435,12 +480,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@ -460,10 +509,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@ -475,7 +529,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...")
print()

View File

@ -31,11 +31,16 @@ from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
def ensure_safe_goal_position(
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
goal_pos: torch.Tensor,
present_pos: torch.Tensor,
max_relative_target: float | list[float],
):
# Cap relative action target magnitude for safety.
diff = goal_pos - present_pos
@ -277,7 +282,9 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
# Check both arms can be read
for name in self.follower_arms:
@ -309,18 +316,26 @@ class ManipulatorRobot:
print(f"Missing calibration file '{arm_calib_path}'")
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
run_arm_calibration,
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_calibration(
arm, self.robot_type, name, arm_type
)
elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
print(
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
@ -339,13 +354,17 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
raise ValueError(
"To run set robot preset, the torque must be disabled on all motors."
)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
all_motors_except_gripper = [
name for name in arm.motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper)
@ -374,7 +393,9 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
def set_aloha_robot_preset(self):
def set_shadow_(arm):
@ -404,11 +425,15 @@ class ManipulatorRobot:
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name for name in self.follower_arms[name].motor_names if name != "gripper"
name
for name in self.follower_arms[name].motor_names
if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
self.follower_arms[name].write(
"Operating_Mode", 4, all_motors_except_gripper
)
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even tho,
@ -456,7 +481,9 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
self.logs[f"read_leader_{name}_pos_dt_s"] = (
time.perf_counter() - before_lread_t
)
# Send goal position to the follower
follower_goal_pos = {}
@ -477,14 +504,18 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Used when record_data=True
follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
time.perf_counter() - before_fwrite_t
)
# Early exit when recording data is not requested
if not record_data:
@ -497,7 +528,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position
state = []
@ -519,8 +552,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries
obs_dict, action_dict = {}, {}
@ -544,7 +581,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position
state = []
@ -559,8 +598,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries and format to pytorch
obs_dict = {}
@ -606,7 +649,9 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Save tensor to concat and return
action_sent.append(goal_pos)

View File

@ -52,7 +52,9 @@ class StretchRobot(StretchAPI):
def connect(self) -> None:
self.is_connected = self.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
print(
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
raise ConnectionError()
for name in self.cameras:
@ -60,7 +62,9 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
print(
"Could not connect to the cameras, check that all cameras are plugged-in."
)
raise ConnectionError()
self.run_calibration()
@ -105,8 +109,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries
obs_dict, action_dict = {}, {}
@ -150,8 +158,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries
obs_dict = {}

View File

@ -48,7 +48,8 @@ class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected."""
def __init__(
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
self,
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
):
self.message = message
super().__init__(self.message)

View File

@ -17,7 +17,9 @@ import importlib
import logging
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
def is_package_available(
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.

View File

@ -28,7 +28,9 @@ def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
"ignore",
"pkg_resources is deprecated as an API",
category=DeprecationWarning,
)
imageio.mimsave(video_path, stacked_frames, fps=fps)

View File

@ -148,7 +148,10 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
"/".join(
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
)
@ -159,10 +162,26 @@ def print_cuda_memory_usage():
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
print(
"Current GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.memory_allocated(0) / 1024**2
)
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
def capture_timestamp_utc():
@ -232,7 +251,12 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
class TimerManager:
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
def __init__(
self,
elapsed_time_list: list[float] | None = None,
label="Elapsed time",
log=True,
):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log

View File

@ -9,7 +9,7 @@ env:
action_dim: 6
fps: ${fps}
device: mps
wrapper:
crop_params_dict:
observation.images.front: [102, 43, 358, 523]
@ -28,4 +28,4 @@ env:
reward_classifier:
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml

View File

@ -66,7 +66,7 @@ policy:
observation.image: [3, 64, 64]
output_shapes:
action: [7]
camera_number: 1
# Normalization / Unnormalization
@ -79,7 +79,7 @@ policy:
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,

View File

@ -108,20 +108,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
break
if motor_index == -1:
raise ValueError("No motors detected. Please ensure you have one motor connected.")
raise ValueError(
"No motors detected. Please ensure you have one motor connected."
)
print(f"Motor index found at: {motor_index}")
if brand == "feetech":
# Allows ID and BAUDRATE to be written in memory
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
if baudrate != baudrate_des:
print(f"Setting its baudrate to {baudrate_des}")
baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)
# The write can fail, so we allow retries
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
)
time.sleep(0.5)
motor_bus.set_bus_baudrate(baudrate_des)
present_baudrate_idx = motor_bus.read_with_motor_ids(
@ -136,7 +142,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
present_idx = motor_bus.read_with_motor_ids(
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
)
if present_idx != motor_idx_des:
raise OSError("Failed to write index.")
@ -164,12 +172,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
parser.add_argument(
"--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
"--port",
type=str,
required=True,
help="Motors bus port (e.g. dynamixel,feetech)",
)
parser.add_argument(
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
)
parser.add_argument(
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
)
parser.add_argument(
"--ID",
type=int,
required=True,
help="Desired ID of the current motor (e.g. 1,2,3)",
)
parser.add_argument(
"--baudrate",
type=int,
default=1000000,
help="Desired baudrate for the motor (default: 1000000)",
)
args = parser.parse_args()

View File

@ -149,7 +149,11 @@ def init_sim_calibration(robot, cfg):
axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
return {
"start_pos": start_pos,
"axis_directions": axis_directions,
"offsets": offsets,
}
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
@ -170,7 +174,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0))
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
if (
teleop_time_s is not None
and time.perf_counter() - start_teleop_t > teleop_time_s
):
print("Teleoperation processes finished.")
break
@ -202,19 +209,27 @@ def record(
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
if assign_rewards
else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
policy, policy_fps, device, use_amp = init_policy(
pretrained_policy_name_or_path, policy_overrides
)
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
logging.warning(
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
if policy is None and process_action_from_leader is None:
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
raise ValueError(
"Either policy or process_action_fn has to be set to enable control in sim."
)
# initialize listener before sim env
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
@ -256,7 +271,11 @@ def record(
"shape": env.observation_space[obs_key].shape,
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features["action"] = {
"dtype": "float32",
"shape": env.action_space.shape,
"names": None,
}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes
@ -357,7 +376,9 @@ def record(
if events["stop_recording"] or recorded_episodes >= num_episodes:
break
else:
logging.info("Waiting for a few seconds before starting next episode recording...")
logging.info(
"Waiting for a few seconds before starting next episode recording..."
)
busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True)
@ -375,7 +396,12 @@ def record(
def replay(
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
env,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
local_files_only: bool = True,
):
env = env()
@ -422,7 +448,10 @@ if __name__ == "__main__":
parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_record.add_argument(
"--root",
@ -448,7 +477,9 @@ if __name__ == "__main__":
required=True,
help="A description of the task preformed during recording that can be used as a language instruction.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument(
"--run-compute-stats",
type=int,
@ -509,7 +540,10 @@ if __name__ == "__main__":
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_replay.add_argument(
"--root",
@ -523,7 +557,9 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
parser_replay.add_argument(
"--episode", type=int, default=0, help="Index of the episodes to replay."
)
args = parser.parse_args()

View File

@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
cuda_version = (
torch._C._cuda_getCompiledVersion()
if HAS_TORCH and torch.version.cuda is not None
else "N/A"
)
# TODO(aliberts): refactor into an actual command `lerobot env`
@ -77,7 +81,9 @@ def display_sys_info() -> dict:
"Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>",
}
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
print(
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
)
print(format_dict(info))
return info

View File

@ -170,7 +170,10 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
successes = [
info["is_success"] if info is not None else False
for info in info["final_info"]
]
else:
successes = [False] * env.num_envs
@ -184,9 +187,13 @@ def rollout(
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
.numpy()
.mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
# Track the final observation.
@ -204,7 +211,9 @@ def rollout(
if return_observations:
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
stacked_observations[key] = torch.stack(
[obs[key] for obs in all_observations], dim=1
)
ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"):
@ -266,7 +275,9 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
ep_frames.append(
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@ -278,7 +289,9 @@ def eval_policy(
episode_data: dict | None = None
# we dont want progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
progbar = trange(
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@ -289,7 +302,8 @@ def eval_policy(
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
start_seed + (batch_ix * env.num_envs),
start_seed + ((batch_ix + 1) * env.num_envs),
)
rollout_data = rollout(
env,
@ -307,13 +321,22 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
mask = (
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
# Extend metrics.
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
batch_sum_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
batch_max_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "max"
)
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
batch_successes = einops.reduce(
(rollout_data["success"] * mask), "b n -> b", "any"
)
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
@ -326,17 +349,27 @@ def eval_policy(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
start_data_index=(
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert (
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
episode_data = {
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
@ -354,7 +387,9 @@ def eval_policy(
target=write_video,
args=(
str(video_path),
stacked_frames[: done_index + 1], # + 1 to capture the last observation
stacked_frames[
: done_index + 1
], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"],
),
)
@ -363,7 +398,9 @@ def eval_policy(
n_episodes_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
{
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
)
# Wait till all video rendering threads are done.
@ -409,7 +446,11 @@ def eval_policy(
def _compile_episode_data(
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
rollout_data: dict,
done_indices: Tensor,
start_episode_index: int,
start_data_index: int,
fps: float,
) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)`
@ -427,12 +468,16 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"episode_index": torch.tensor(
[start_episode_index + ep_ix] * (num_frames - 1)
),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
torch.float32
),
}
# For the last observation frame, all other keys will just be copy padded.
@ -448,7 +493,9 @@ def _compile_episode_data(
for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
data_dict["index"] = torch.arange(
start_data_index, start_data_index + total_frames, 1
)
return data_dict

View File

@ -46,7 +46,11 @@ import torch
from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
@ -151,11 +161,17 @@ def rollout(
images = []
for key in image_keys:
if display_cameras:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
images.append(observation[key].to("mps"))
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
reward = (
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
all_rewards.append(reward)
# print("REWARD : ", reward)
@ -219,11 +235,19 @@ def eval_policy(
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
reward_classifier = get_classifier(
reward_classifier_pretrained_path, reward_classifier_config_file
)
for _ in progbar:
rollout_data = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
robot,
policy,
reward_classifier,
fps,
control_time_s,
use_amp,
display_cameras,
)
rollouts.append(rollout_data)
@ -289,7 +313,9 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.space:
@ -301,7 +327,10 @@ def init_keyboard_listener():
"Place the leader in similar pose to the follower and press space again."
)
events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
else:
events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
@ -351,7 +380,9 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument(
"--out-dir",
help=(
@ -360,7 +391,8 @@ if __name__ == "__main__":
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
"--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
)
parser.add_argument(
"--reward-classifier-pretrained-path",

View File

@ -45,9 +45,13 @@ def find_port():
print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the USB cable.")
elif len(ports_diff) == 0:
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. No difference was found ({ports_diff})."
)
else:
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. More than one port was found ({ports_diff})."
)
if __name__ == "__main__":

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import io
@ -737,7 +736,6 @@ def concatenate_batch_transitions(
if __name__ == "__main__":
import numpy as np
from tempfile import TemporaryDirectory
# ===== Test 1: Create and use a synthetic ReplayBuffer =====
@ -1139,7 +1137,7 @@ if __name__ == "__main__":
savings_percent = (std_mem - opt_mem) / std_mem * 100
print(f"\nMemory optimization result:")
print("\nMemory optimization result:")
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")

View File

@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser = argparse.ArgumentParser(
description="Crop rectangular ROIs from a LeRobot dataset."
)
parser.add_argument(
"--repo-id",
type=str,
@ -247,7 +249,9 @@ if __name__ == "__main__":
args = parser.parse_args()
local_files_only = args.root is not None
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
dataset = LeRobotDataset(
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
@ -256,7 +260,7 @@ if __name__ == "__main__":
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path, "r") as f:
with open(args.crop_params_path) as f:
rois = json.load(f)
# rois = {

View File

@ -31,7 +31,9 @@ def find_joint_bounds(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t
@ -57,7 +59,12 @@ if __name__ == "__main__":
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)

View File

@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device:str
cfg: DictConfig, logger: Logger, device: str, storage_device: str
) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(

View File

@ -10,7 +10,9 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
def action(self, action):
action, telop = action
@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
def step(self, action):
if isinstance(action, tuple):
@ -137,7 +145,9 @@ def make_maniskill(
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
@ -149,10 +159,11 @@ def make_maniskill(
if __name__ == "__main__":
import argparse
import hydra
from omegaconf import OmegaConf
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
args = parser.parse_args()
# Initialize config

View File

@ -73,7 +73,9 @@ def make_optimizer_and_scheduler(cfg, policy):
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
optimizer_params_dicts,
lr=cfg.training.lr,
weight_decay=cfg.training.weight_decay,
)
lr_scheduler = None
elif cfg.policy.name == "diffusion":
@ -100,14 +102,23 @@ def make_optimizer_and_scheduler(cfg, policy):
optimizer = torch.optim.Adam(
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
{
"params": policy.critic_ensemble.parameters(),
"lr": policy.config.critic_lr,
},
{
"params": policy.temperature.parameters(),
"lr": policy.config.temperature_lr,
},
]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
from lerobot.common.policies.vqbet.modeling_vqbet import (
VQBeTOptimizer,
VQBeTScheduler,
)
optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg)
@ -214,7 +225,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

View File

@ -14,7 +14,6 @@
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat
import hydra
@ -28,14 +27,16 @@ from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
format_big_number,
@ -50,7 +51,11 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
return model
@ -62,7 +67,9 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@ -71,7 +78,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
def train_epoch(
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
@ -85,7 +94,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
outputs = model(images)
loss = criterion(outputs.logits, labels)
@ -130,7 +143,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@ -143,7 +158,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
):
outputs = model(images)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
else:
outputs = model(images)
@ -161,16 +178,24 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
f"image_{img_key}": wandb.Image(
images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
@ -238,15 +263,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
elif device.type == "mps":
torch.mps.synchronize()
with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
with (
profiler.profile(record_shapes=True) as prof,
profiler.record_function("model_inference"),
):
_ = model(x)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
inference_times = np.array(inference_times)
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
avg, median, std = (
inference_times.mean(),
np.median(inference_times),
inference_times.std(),
)
print(
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
)
@ -264,7 +298,11 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
@hydra.main(
version_base="1.2",
config_path="../configs/policy",
config_name="hilserl_classifier",
)
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))
@ -278,7 +316,9 @@ def train(cfg: DictConfig) -> None:
# Setup dataset and dataloaders
dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
cfg.dataset_repo_id,
root=cfg.dataset_root,
local_files_only=cfg.local_files_only,
)
logging.info(f"Dataset size: {len(dataset)}")
@ -314,7 +354,9 @@ def train(cfg: DictConfig) -> None:
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
@ -327,7 +369,9 @@ def train(cfg: DictConfig) -> None:
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
@ -346,7 +390,11 @@ def train(cfg: DictConfig) -> None:
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters
@ -362,7 +410,17 @@ def train(cfg: DictConfig) -> None:
for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Periodic validation
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:

View File

@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict
import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm
@ -30,20 +29,17 @@ from tqdm import tqdm
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy):
@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy):
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@ -198,8 +198,12 @@ class ReplayBuffer:
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
replay_buffer = cls(
capacity=len(lerobot_dataset), device=device, state_keys=state_keys
)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
replay_buffer.add(
@ -244,7 +248,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = []
num_frames = len(dataset)
@ -298,36 +304,40 @@ class ReplayBuffer:
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_state[key] = torch.cat(
[t["state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
self.device
)
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_next_state[key] = torch.cat(
[t["next_state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
@ -344,7 +354,13 @@ def concatenate_batch_transitions(
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
@ -355,7 +371,11 @@ def concatenate_batch_transitions(
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
[
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
device=device,
)
assert isinstance(policy, nn.Module)
@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# TODO: Handle resume
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
batch_size = cfg.training.batch_size
@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if interaction_step >= cfg.training.online_step_before_learning:
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
next_obs, reward, done, truncated, info = online_env.step(
action.cpu().numpy()
)
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
action = torch.tensor(action, dtype=torch.float32).to(
device, non_blocking=True
)
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Because we are using a single environment
# we can safely assume that the episode is done
if done[0] or truncated[0]:
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
logging.info(
f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logger.log_dict(
{"Sum episode reward": sum_reward_episode}, interaction_step
)
sum_reward_episode = 0
# HACK: This is for maniskill
logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
)
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
logger.log_dict(
{"Episode success": info["success"].float().item()}, interaction_step
)
replay_buffer.add(
state=obs,
@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(
observations=observations
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@ -573,7 +611,9 @@ def train_cli(cfg: dict):
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()

View File

@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
assert (
c < h and c < w
), f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
)
return hwc_uint8_numpy

View File

@ -81,7 +81,11 @@ def run_server(
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app = Flask(
__name__,
static_folder=static_folder.resolve(),
template_folder=template_folder.resolve(),
)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
@ -138,8 +142,12 @@ def run_server(
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
@app.route(
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
)
def show_episode(
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
@ -171,15 +179,21 @@ def run_server(
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
dataset.meta.get_video_file_path(episode_id, key)
for key in dataset.meta.video_keys
]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
{
"url": url_for("static", filename=video_path),
"filename": video_path.parent.name,
}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
video_keys = [
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@ -198,16 +212,24 @@ def run_server(
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
tasks_jsonl = [
json.loads(line) for line in response.text.splitlines() if line.strip()
]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
filtered_tasks_jsonl = [
row for row in tasks_jsonl if row["episode_index"] == episode_id
]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
range(
dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes
)
)
return render_template(
@ -255,7 +277,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else dataset.features[column_name].shape[0]
)
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
if (
"names" in dataset.features[column_name]
and dataset.features[column_name]["names"]
):
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
@ -278,8 +303,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
url = (
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
@ -312,7 +341,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
def get_episode_language_instruction(
dataset: LeRobotDataset, ep_index: int
) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
@ -323,7 +354,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
"', shape=(), dtype=string)"
)
def get_dataset_info(repo_id: str) -> IterableNamespace:
@ -358,7 +391,9 @@ def visualize_dataset_html(
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
logging.info(
f"Output directory already exists. Loading from it: '{output_dir}'"
)
output_dir.mkdir(parents=True, exist_ok=True)

View File

@ -52,7 +52,13 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")

View File

@ -30,7 +30,9 @@ class config: # noqa: N801
def enable_device(self, device_id: str):
self.device_enabled = device_id
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
def enable_stream(
self, stream_type: stream, width=None, height=None, color_format=None, fps=None
):
self.stream_type = stream_type
# Overwrite default values when possible
self.width = 848 if width is None else width

View File

@ -37,7 +37,10 @@ pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]'
import numpy as np
import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
# Maximum absolute difference between two consecutive images recorded by a camera.
@ -112,7 +115,11 @@ def test_camera(request, camera_type, mock):
)
# TODO(rcadene): properly set `rtol`
np.testing.assert_allclose(
color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
color_image,
async_color_image,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
)
# Test disconnecting
@ -131,7 +138,11 @@ def test_camera(request, camera_type, mock):
assert camera.color_mode == "bgr"
bgr_color_image = camera.read()
np.testing.assert_allclose(
color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
color_image,
bgr_color_image[:, :, [2, 1, 0]],
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
)
del camera
@ -166,7 +177,11 @@ def test_camera(request, camera_type, mock):
rot_color_image = camera.read()
np.testing.assert_allclose(
rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
rot_color_image,
manual_rot_img,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
)
del camera
@ -200,7 +215,9 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
from lerobot.common.robot_devices.cameras.intelrealsense import (
save_images_from_cameras,
)
# Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)

View File

@ -91,7 +91,12 @@ def patch_builtins_input(monkeypatch):
def pytest_addoption(parser):
parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
parser.addoption(
"--seed",
action="store",
default="42",
help="Set random seed for reproducibility",
)
@pytest.fixture(autouse=True)

View File

@ -364,10 +364,16 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
for transform in transforms:
transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
assert any(
transform_dir.iterdir()
), f"No transformed images found in {transform} directory."
# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
"min.png",
"max.png",
"mean.png",
]
for file_name in expected_files:
assert (transform_dir / file_name).exists(), (
f"{file_name} was not found in {transform} directory."

View File

@ -187,7 +187,9 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
np.uint8
)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
@ -202,7 +204,9 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
np.uint8
)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
@ -292,7 +296,9 @@ def test_wait_until_done(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
num_images = 100
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
image_arrays = [
img_array_factory(height=500, width=500) for _ in range(num_images)
]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)

View File

@ -44,13 +44,23 @@ def make_new_buffer(
return buffer, write_dir
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
def make_spoof_data_frames(
n_episodes: int, n_frames_per_episode: int
) -> dict[str, np.ndarray]:
new_data = {
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
data_key: np.arange(
n_frames_per_episode * n_episodes * np.prod(data_shape)
).reshape(-1, *data_shape),
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
np.arange(n_episodes), n_frames_per_episode
),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(
np.arange(n_frames_per_episode), n_episodes
),
OnlineBuffer.TIMESTAMP_KEY: np.tile(
np.arange(n_frames_per_episode) / fps, n_episodes
),
}
return new_data
@ -219,47 +229,72 @@ def test_compute_sampler_weights_trivial(
online_dataset_size: int,
online_sampling_ratio: float,
):
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=offline_dataset_size
)
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
make_spoof_data_frames(
n_episodes=2, n_frames_per_episode=online_dataset_size // 2
)
)
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
)
if offline_dataset_size == 0 or online_dataset_size == 0:
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
elif online_sampling_ratio == 0:
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
expected_weights = torch.cat(
[torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
)
elif online_sampling_ratio == 1:
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
expected_weights = torch.cat(
[torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
)
expected_weights /= expected_weights.sum()
torch.testing.assert_close(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=4
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
online_sampling_ratio = 0.8
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=online_sampling_ratio,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
lerobot_dataset_factory, tmp_path
):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=4
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
offline_dataset,
online_dataset=online_dataset,
online_sampling_ratio=0.8,
online_drop_n_last_frames=1,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
@ -268,9 +303,13 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
"""Note: test copied from test_sampler."""
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
offline_dataset = lerobot_dataset_factory(
tmp_path, total_episodes=1, total_frames=2
)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_dataset.add_data(
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
)
weights = compute_sampler_weights(
offline_dataset,

View File

@ -15,7 +15,9 @@
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
hf_transform_to_torch,

View File

@ -20,17 +20,39 @@ DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"laptop": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
"phone": {
"shape": (480, 640, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {

View File

@ -23,7 +23,11 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
CODEBASE_VERSION,
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
@ -54,7 +58,9 @@ def get_task_index(task_dicts: dict, task: str) -> int:
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
def _create_img_tensor(
height=100, width=100, channels=3, dtype=torch.float32
) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@ -62,10 +68,14 @@ def img_tensor_factory():
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
def _create_img_array(
height=100, width=100, channels=3, dtype=np.uint8
) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
img_array = np.random.randint(
0, 256, size=(height, width, channels), dtype=dtype
)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
@ -94,10 +104,13 @@ def features_factory():
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
camera_ft = {
key: {"dtype": "image", **ft} for key, ft in camera_features.items()
}
return {
**motor_features,
**camera_ft,
@ -215,7 +228,9 @@ def episodes_factory(tasks_factory):
if total_episodes <= 0 or total_frames <= 0:
raise ValueError("num_episodes and total_length must be positive integers.")
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
raise ValueError(
"total_length must be greater than or equal to num_episodes."
)
if not tasks:
min_tasks = 2 if multi_task else 1
@ -223,10 +238,14 @@ def episodes_factory(tasks_factory):
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
raise ValueError(
"The number of tasks should be less than the number of episodes."
)
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
lengths = np.random.multinomial(
total_frames, [1 / total_episodes] * total_episodes
).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
num_tasks_available = len(tasks_list)
@ -234,9 +253,13 @@ def episodes_factory(tasks_factory):
episodes = {}
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
num_tasks_in_episode = (
random.randint(1, min(3, num_tasks_available)) if multi_task else 1
)
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
episode_tasks = random.sample(
tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))
)
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
@ -253,7 +276,9 @@ def episodes_factory(tasks_factory):
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def hf_dataset_factory(
features_factory, tasks_factory, episodes_factory, img_array_factory
):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
@ -275,10 +300,15 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
(
episode_index_col,
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
)
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
task_index = np.concatenate(
(task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))
)
index_col = np.arange(len(episode_index_col))
@ -290,7 +320,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
robot_cols[key] = np.random.random(
(len(index_col), ft["shape"][0])
).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
@ -340,7 +372,9 @@ def lerobot_dataset_metadata_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
@ -392,7 +426,9 @@ def lerobot_dataset_factory(
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
)
if not stats:
stats = stats_factory(features=info["features"])
@ -408,7 +444,9 @@ def lerobot_dataset_factory(
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
hf_dataset = hf_dataset_factory(
tasks=tasks, episodes=episode_dicts, fps=info["fps"]
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,

View File

@ -102,7 +102,10 @@ def episode_path(episodes_factory):
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
dir: Path,
ep_idx: int = 0,
hf_dataset: datasets.Dataset | None = None,
info: dict | None = None,
) -> Path:
if not info:
info = info_factory()

24
tests/fixtures/hub.py vendored
View File

@ -67,15 +67,21 @@ def mock_snapshot_download_factory(
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
hf_dataset = hf_dataset_factory(
tasks=tasks, episodes=episodes, fps=info["fps"]
)
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
episode_index = int(
path.stem[len("episode_") :]
) # 'episode_000000' -> 0
return episode_index
else:
return None
@ -100,12 +106,16 @@ def mock_snapshot_download_factory(
for episode_dict in episodes.values():
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_path = info["data_path"].format(
episode_chunk=ep_chunk, episode_index=ep_idx
)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
all_files,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
# Create allowed files
@ -113,7 +123,9 @@ def mock_snapshot_download_factory(
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
_ = single_episode_parquet_path(
local_dir, episode_index, hf_dataset, info
)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:

View File

@ -80,7 +80,9 @@ class GroupSyncRead:
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
self.packet_handler.data[motor_index] = get_default_motor_values(
motor_index
)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS

View File

@ -91,7 +91,9 @@ class GroupSyncRead:
def addParam(self, motor_index): # noqa: N802
# Initialize motor default values
if motor_index not in self.packet_handler.data:
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
self.packet_handler.data[motor_index] = get_default_motor_values(
motor_index
)
def txRxPacket(self): # noqa: N802
return COMM_SUCCESS

View File

@ -43,7 +43,10 @@ import time
import numpy as np
import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.scripts.find_motors_bus_port import find_port
from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor
@ -76,7 +79,9 @@ def test_configure_motors_all_ids_1(request, motor_type, mock):
else:
raise ValueError(motor_type)
input("Are you sure you want to re-configure the motors? Press enter to continue...")
input(
"Are you sure you want to re-configure the motors? Press enter to continue..."
)
# This test expect the configuration was already correct.
motors_bus = make_motors_bus(motor_type, mock=mock)
motors_bus.connect()

View File

@ -25,7 +25,10 @@ from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig,
)
BATCH_SIZE = 1000
LR = 0.1
@ -43,7 +46,9 @@ def train_evaluate_multiclass_classifier():
logging.info(
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
)
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
multiclass_config = ClassifierConfig(
model_name="microsoft/resnet-18", device=DEVICE, num_classes=10
)
multiclass_classifier = Classifier(multiclass_config)
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
@ -114,10 +119,18 @@ def train_evaluate_multiclass_classifier():
test_probs = torch.stack(test_probs)
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
precision = Precision(
task="multiclass", average="weighted", num_classes=multiclass_num_classes
)
recall = Recall(
task="multiclass", average="weighted", num_classes=multiclass_num_classes
)
f1 = F1Score(
task="multiclass", average="weighted", num_classes=multiclass_num_classes
)
auroc = AUROC(
task="multiclass", num_classes=multiclass_num_classes, average="weighted"
)
# Calculate metrics
acc = accuracy(test_predictions, test_labels)
@ -146,18 +159,28 @@ def train_evaluate_binary_classifier():
new_label = float(1.0) if label == target_class else float(0.0)
new_targets.append(new_label)
dataset.targets = new_targets # Replace the original labels with the binary ones
dataset.targets = (
new_targets # Replace the original labels with the binary ones
)
return dataset
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
binary_train_dataset = CIFAR10(
root="data", train=True, download=True, transform=ToTensor()
)
binary_test_dataset = CIFAR10(
root="data", train=False, download=True, transform=ToTensor()
)
# Apply one-vs-rest labeling
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
binary_trainloader = DataLoader(
binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
binary_testloader = DataLoader(
binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False
)
binary_epoch = 1

View File

@ -9,7 +9,9 @@ from tests.utils import require_package
def test_classifier_output():
output = ClassifierOutput(
logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
logits=torch.tensor([1, 2, 3]),
probabilities=torch.tensor([0.1, 0.2, 0.3]),
hidden_states=None,
)
assert (
@ -20,7 +22,9 @@ def test_classifier_output():
@require_package("transformers")
def test_binary_classifier_with_default_params():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
config = ClassifierConfig()
classifier = Classifier(config)
@ -41,7 +45,9 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
def test_multiclass_classifier():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
num_classes = 5
config = ClassifierConfig(num_classes=num_classes)
@ -63,7 +69,9 @@ def test_multiclass_classifier():
@require_package("transformers")
def test_default_device():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
config = ClassifierConfig()
assert config.device == "cpu"
@ -75,7 +83,9 @@ def test_default_device():
@require_package("transformers")
def test_explicit_device_setup():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
config = ClassifierConfig(device="meta")
assert config.device == "meta"

View File

@ -172,7 +172,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert set(batch) == set(
batch_
), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
for k in batch
@ -186,7 +188,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
observation = preprocess_observation(observation)
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
observation = {
key: observation[key].to(DEVICE, non_blocking=True) for key in observation
}
# get the next action for the environment (also check that the observation batch is not modified)
observation_ = deepcopy(observation)
@ -452,7 +456,9 @@ def test_act_temporal_ensembler():
batch_size = batch_seq.shape[0]
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
# dimension of `batch_seq`.
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(
-1
)
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
for i in range(episode_length):
@ -475,7 +481,8 @@ def test_act_temporal_ensembler():
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
offline_avg = (
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum")
/ weights[: i + 1].sum()
)
# Sanity check. The average should be between the extrema.
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)

View File

@ -335,8 +335,12 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
)
dataset = record(robot, rec_cfg)
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events[
"rerecord_episode"
], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events[
"exit_early"
], "`exit_early` wasn't properly reset to False"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@ -391,7 +395,8 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
@pytest.mark.parametrize(
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
"robot_type, mock, num_image_writer_processes",
[("koch", True, 0), ("koch", True, 1)],
)
@require_robot
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):

View File

@ -105,7 +105,9 @@ def test_robot(tmp_path, request, robot_type, mock):
assert "observation.state" in observation
assert isinstance(observation["observation.state"], torch.Tensor)
assert observation["observation.state"].ndim == 1
dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
dim_state = sum(
len(robot.follower_arms[name].motors) for name in robot.follower_arms
)
assert observation["observation.state"].shape[0] == dim_state
# Cameras
for name in robot.cameras:
@ -116,7 +118,9 @@ def test_robot(tmp_path, request, robot_type, mock):
assert "action" in action
assert isinstance(action["action"], torch.Tensor)
assert action["action"].ndim == 1
dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
dim_action = sum(
len(robot.follower_arms[name].motors) for name in robot.follower_arms
)
assert action["action"].shape[0] == dim_action
# TODO(rcadene): test if observation and action data are returned as expected

View File

@ -9,7 +9,9 @@ from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
create_balanced_sampler,
@ -34,7 +36,9 @@ class MockDataset(Dataset):
def make_dummy_model():
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
num_classes=2,
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
num_cameras=1,
)
model = Classifier(config=model_config)
return model
@ -65,7 +69,9 @@ def test_create_balanced_sampler():
labels = [item["label"] for item in data]
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
class_weights = 1.0 / class_counts
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
expected_weights = torch.tensor(
[class_weights[label] for label in labels], dtype=torch.float32
)
# Test that the weights are correct
assert torch.allclose(weights, expected_weights)
@ -149,7 +155,9 @@ def test_validate():
def test_train_epoch_multiple_cameras():
model_config = ClassifierConfig(
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
num_classes=2,
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
num_cameras=2,
)
model = Classifier(config=model_config)
@ -216,10 +224,16 @@ def test_resume_function(
):
# Initialize Hydra
test_file_dir = os.path.dirname(os.path.abspath(__file__))
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
config_dir = os.path.abspath(
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
)
assert os.path.exists(
config_dir
), f"Config directory does not exist at {config_dir}"
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
with initialize_config_dir(
config_dir=config_dir, job_name="test_app", version_base="1.2"
):
cfg = compose(
config_name="hilserl_classifier",
overrides=[
@ -244,7 +258,9 @@ def test_resume_function(
mock_init_hydra_config.return_value = cfg
# Mock dataset
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
dataset = MockDataset(
[{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]
)
mock_dataset.return_value = dataset
# Mock checkpoint handling

View File

@ -47,7 +47,9 @@ for motor_type in available_motors:
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
DYNAMIXEL_PORT = os.environ.get(
"LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081"
)
DYNAMIXEL_MOTORS = {
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
@ -57,7 +59,9 @@ DYNAMIXEL_MOTORS = {
"gripper": [6, "xl330-m288"],
}
FEETECH_PORT = os.environ.get("LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971")
FEETECH_PORT = os.environ.get(
"LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971"
)
FEETECH_MOTORS = {
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
@ -156,9 +160,13 @@ def require_package_arg(func):
if "required_packages" in arg_names:
# Get the index of 'required_packages' and retrieve the value from args
index = arg_names.index("required_packages")
required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
required_packages = (
args[index] if len(args) > index else kwargs.get("required_packages")
)
else:
raise ValueError("Function does not have 'required_packages' as an argument.")
raise ValueError(
"Function does not have 'required_packages' as an argument."
)
if required_packages is None:
return func(*args, **kwargs)
@ -215,11 +223,17 @@ def require_robot(func):
mock = kwargs.get("mock")
if robot_type is None:
raise ValueError("The 'robot_type' must be an argument of the test function.")
raise ValueError(
"The 'robot_type' must be an argument of the test function."
)
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
raise ValueError(
"The 'request' fixture must be an argument of the test function."
)
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
raise ValueError(
"The 'mock' variable must be an argument of the test function."
)
# Run test with a real robot. Skip test if robot connection fails.
if not mock and not request.getfixturevalue("is_robot_available"):
@ -239,11 +253,17 @@ def require_camera(func):
mock = kwargs.get("mock")
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
raise ValueError(
"The 'request' fixture must be an argument of the test function."
)
if camera_type is None:
raise ValueError("The 'camera_type' must be an argument of the test function.")
raise ValueError(
"The 'camera_type' must be an argument of the test function."
)
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
raise ValueError(
"The 'mock' variable must be an argument of the test function."
)
if not mock and not request.getfixturevalue("is_camera_available"):
pytest.skip(f"A {camera_type} camera is not available.")
@ -262,11 +282,17 @@ def require_motor(func):
mock = kwargs.get("mock")
if request is None:
raise ValueError("The 'request' fixture must be an argument of the test function.")
raise ValueError(
"The 'request' fixture must be an argument of the test function."
)
if motor_type is None:
raise ValueError("The 'motor_type' must be an argument of the test function.")
raise ValueError(
"The 'motor_type' must be an argument of the test function."
)
if mock is None:
raise ValueError("The 'mock' variable must be an argument of the test function.")
raise ValueError(
"The 'mock' variable must be an argument of the test function."
)
if not mock and not request.getfixturevalue("is_motor_available"):
pytest.skip(f"A {motor_type} motor is not available.")
@ -285,7 +311,14 @@ def mock_calibration_dir(calibration_dir):
"start_pos": [1442, 843, 2166, 2849, 1988, 1835],
"end_pos": [2440, 1869, -1106, -1848, -926, 3235],
"calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"motor_names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
}
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
with open(calibration_dir / "main_follower.json", "w") as f: