[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-03-04 13:38:47 +00:00 committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions


@@ -32,7 +32,11 @@ import numpy as np
import pandas as pd
import PIL
import torch
-from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
+from skimage.metrics import (
+    mean_squared_error,
+    peak_signal_noise_ratio,
+    structural_similarity,
+)
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int:
    return total_size
-def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
+def load_original_frames(
+    imgs_dir: Path, timestamps: list[float], fps: int
+) -> torch.Tensor:
    frames = []
    for ts in timestamps:
        idx = int(ts * fps)
@@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
-    imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
+    imgs_dir: Path,
+    save_dir: Path,
+    frames: torch.Tensor,
+    timestamps: list[float],
+    fps: int,
) -> None:
    if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
        return
@@ -104,7 +114,10 @@ def save_decoded_frames(
        idx = int(ts * fps)
        frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
        PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
-        shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
+        shutil.copyfile(
+            imgs_dir / f"frame_{idx:06d}.png",
+            save_dir / f"frame_{idx:06d}_original.png",
+        )
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
    hf_dataset = dataset.hf_dataset.with_format(None)
    # We only save images from the first camera
-    img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
+    img_keys = [
+        key for key in hf_dataset.features if key.startswith("observation.image")
+    ]
    imgs_dataset = hf_dataset.select_columns(img_keys[0])
    for i, item in enumerate(
-        tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
+        tqdm(
+            imgs_dataset,
+            desc=f"saving {dataset.repo_id} first episode images",
+            leave=False,
+        )
    ):
        img = item[img_keys[0]]
        img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
            break
-def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
+def sample_timestamps(
+    timestamps_mode: str, ep_num_images: int, fps: int
+) -> list[float]:
    # Start at 5 to allow for 2_frames_4_space and 6_frames
    idx = random.randint(5, ep_num_images - 1)
    match timestamps_mode:
@@ -154,7 +175,9 @@ def decode_video_frames(
    backend: str,
) -> torch.Tensor:
    if backend in ["pyav", "video_reader"]:
-        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+        return decode_video_frames_torchvision(
+            video_path, timestamps, tolerance_s, backend
+        )
    else:
        raise NotImplementedError(backend)
@@ -181,7 +204,9 @@ def benchmark_decoding(
        }
        with time_benchmark:
-            frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
+            frames = decode_video_frames(
+                video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
+            )
        result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
        with time_benchmark:
@@ -190,12 +215,18 @@ def benchmark_decoding(
        frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
        for i in range(num_frames):
-            result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
+            result["mse_values"].append(
+                mean_squared_error(original_frames_np[i], frames_np[i])
+            )
            result["psnr_values"].append(
-                peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
+                peak_signal_noise_ratio(
+                    original_frames_np[i], frames_np[i], data_range=1.0
+                )
            )
            result["ssim_values"].append(
-                structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
+                structural_similarity(
+                    original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
+                )
            )
        if save_frames and sample == 0:
@@ -215,7 +246,9 @@ def benchmark_decoding(
    # As these samples are independent, we run them in parallel threads to speed up the benchmark.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process_sample, i) for i in range(num_samples)]
-        for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
+        for future in tqdm(
+            as_completed(futures), total=num_samples, desc="samples", leave=False
+        ):
            result = future.result()
            load_times_video_ms.append(result["load_time_video_ms"])
            load_times_images_ms.append(result["load_time_images_ms"])
@@ -275,9 +308,13 @@ def benchmark_encoding_decoding(
    random.seed(seed)
    benchmark_table = []
    for timestamps_mode in tqdm(
-        decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
+        decoding_cfg["timestamps_modes"],
+        desc="decodings (timestamps_modes)",
+        leave=False,
    ):
-        for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
+        for backend in tqdm(
+            decoding_cfg["backends"], desc="decodings (backends)", leave=False
+        ):
            benchmark_row = benchmark_decoding(
                imgs_dir,
                video_path,
@@ -355,14 +392,23 @@ def main(
        imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
        # We only use the first episode
        save_first_episode(imgs_dir, dataset)
-        for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
+        for key, values in tqdm(
+            encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
+        ):
            for value in tqdm(values, desc=f"encodings ({key})", leave=False):
                encoding_cfg = BASE_ENCODING.copy()
                encoding_cfg["vcodec"] = video_codec
                encoding_cfg["pix_fmt"] = pixel_format
                encoding_cfg[key] = value
-                args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
-                video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
+                args_path = Path(
+                    "_".join(str(value) for value in encoding_cfg.values())
+                )
+                video_path = (
+                    output_dir
+                    / "videos"
+                    / args_path
+                    / f"{repo_id.replace('/', '_')}.mp4"
+                )
                benchmark_table += benchmark_encoding_decoding(
                    dataset,
                    video_path,
@@ -388,7 +434,9 @@ def main(
    # Concatenate all results
    df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
    concatenated_df = pd.concat(df_list, ignore_index=True)
-    concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
+    concatenated_path = (
+        output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
+    )
    concatenated_df.to_csv(concatenated_path, header=True, index=False)


@@ -32,7 +32,10 @@ import torch
from huggingface_hub import HfApi
import lerobot
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.common.datasets.lerobot_dataset import (
+    LeRobotDataset,
+    LeRobotDatasetMetadata,
+)
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
@@ -40,7 +43,10 @@ pprint(lerobot.available_datasets)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
-repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
+repo_ids = [
+    info.id
+    for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
+]
pprint(repo_ids)
# Or simply explore them in your web browser directly at:
@@ -55,7 +61,9 @@ ds_meta = LeRobotDatasetMetadata(repo_id)
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
-print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
+print(
+    f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
+)
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")


@@ -48,10 +48,14 @@ transforms = v2.Compose(
)
# Create another LeRobotDataset with the defined transformations
-transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
+transformed_dataset = LeRobotDataset(
+    dataset_repo_id, episodes=[0], image_transforms=transforms
+)
# Get a frame from the transformed dataset
-transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
+transformed_frame = transformed_dataset[first_idx][
+    transformed_dataset.meta.camera_keys[0]
+]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")


@@ -26,7 +26,10 @@ import math
import torch
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.common.datasets.lerobot_dataset import (
+    LeRobotDataset,
+    LeRobotDatasetMetadata,
+)
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy


@@ -0,0 +1,228 @@
import shutil
from pathlib import Path

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
PUSHT_FEATURES = {
    "observation.state": {
        "dtype": "float32",
        "shape": (2,),
        "names": {
            "axes": ["x", "y"],
        },
    },
    "action": {
        "dtype": "float32",
        "shape": (2,),
        "names": {
            "axes": ["x", "y"],
        },
    },
    "next.reward": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
    "next.success": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    "observation.environment_state": {
        "dtype": "float32",
        "shape": (16,),
        "names": [
            "keypoints",
        ],
    },
    "observation.image": {
        "dtype": None,
        "shape": (3, 96, 96),
        "names": [
            "channel",
            "height",
            "width",
        ],
    },
}


def build_features(mode: str) -> dict:
    features = PUSHT_FEATURES
    if mode == "keypoints":
        features.pop("observation.image")
    else:
        features.pop("observation.environment_state")
        features["observation.image"]["dtype"] = mode

    return features


def load_raw_dataset(zarr_path: Path):
    try:
        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        print(
            "`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
        )
        raise e

    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
    return zarr_data


def calculate_coverage(zarr_data):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
    except ModuleNotFoundError as e:
        print(
            "`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
        )
        raise e

    block_pos = zarr_data["state"][:, 2:4]
    block_angle = zarr_data["state"][:, 4]

    num_frames = len(block_pos)

    coverage = np.zeros((num_frames,))
    # 8 keypoints with 2 coords each
    keypoints = np.zeros((num_frames, 16))

    # Set x, y, theta (in radians)
    goal_pos_angle = np.array([256, 256, np.pi / 4])
    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

    for i in range(num_frames):
        space = pymunk.Space()
        space.gravity = 0, 0
        space.damping = 0

        # Add walls.
        walls = [
            PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
            PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
            PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
            PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
        ]
        space.add(*walls)

        block_body, block_shapes = PushTEnv.add_tee(
            space, block_pos[i].tolist(), block_angle[i].item()
        )
        goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
        block_geom = pymunk_to_shapely(block_body, block_body.shapes)
        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage[i] = intersection_area / goal_area
        keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

    return coverage, keypoints


def calculate_success(coverage: float, success_threshold: float):
    return coverage > success_threshold


def calculate_reward(coverage: float, success_threshold: float):
    return np.clip(coverage / success_threshold, 0, 1)


def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
    if mode not in ["video", "image", "keypoints"]:
        raise ValueError(mode)

    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    if not raw_dir.exists():
        download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")

    zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")

    env_state = zarr_data["state"][:]
    agent_pos = env_state[:, :2]

    action = zarr_data["action"][:]
    image = zarr_data["img"]  # (b, h, w, c)

    episode_data_index = {
        "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
        "to": zarr_data.meta["episode_ends"],
    }

    # Calculate success and reward based on the overlapping area
    # of the T-object and the T-area.
    coverage, keypoints = calculate_coverage(zarr_data)
    success = calculate_success(coverage, success_threshold=0.95)
    reward = calculate_reward(coverage, success_threshold=0.95)

    features = build_features(mode)
    dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=10,
        robot_type="2d pointer",
        features=features,
        image_writer_threads=4,
    )
    episodes = range(len(episode_data_index["from"]))
    for ep_idx in episodes:
        from_idx = episode_data_index["from"][ep_idx]
        to_idx = episode_data_index["to"][ep_idx]
        num_frames = to_idx - from_idx

        for frame_idx in range(num_frames):
            i = from_idx + frame_idx
            frame = {
                "action": torch.from_numpy(action[i]),
                # Shift reward and success by +1 until the last item of the episode
                "next.reward": reward[i + (frame_idx < num_frames - 1)],
                "next.success": success[i + (frame_idx < num_frames - 1)],
            }

            frame["observation.state"] = torch.from_numpy(agent_pos[i])

            if mode == "keypoints":
                frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
            else:
                frame["observation.image"] = torch.from_numpy(image[i])

            dataset.add_frame(frame)

        dataset.save_episode(task=PUSHT_TASK)

    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    # To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
    repo_id = "lerobot/pusht"

    modes = ["video", "image", "keypoints"]

    # Uncomment if you want to try with a specific mode
    # modes = ["video"]
    # modes = ["image"]
    # modes = ["keypoints"]

    raw_dir = Path("data/lerobot-raw/pusht_raw")
    for mode in modes:
        if mode in ["image", "keypoints"]:
            repo_id += f"_{mode}"

        # download and load raw dataset, create LeRobotDataset, populate it, push to hub
        main(raw_dir, repo_id=repo_id, mode=mode)

        # Uncomment if you want to load the local dataset and explore it
        # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
        # breakpoint()


@@ -164,7 +164,11 @@ available_real_world_datasets = [
]
available_datasets = sorted(
-    set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
+    set(
+        itertools.chain(
+            *available_datasets_per_env.values(), available_real_world_datasets
+        )
+    )
)
# lists all available policies from `lerobot/common/policies`
@@ -205,9 +209,13 @@ available_policies_per_env = {
    "aloha_real": ["act_aloha_real"],
}
-env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
+env_task_pairs = [
+    (env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
+]
env_dataset_pairs = [
-    (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
+    (env, dataset)
+    for env, datasets in available_datasets_per_env.items()
+    for dataset in datasets
]
env_dataset_policy_triplets = [
    (env, dataset, policy)


@@ -127,7 +127,9 @@ class AsyncImageWriter:
        self._stopped = False
        if num_threads <= 0 and num_processes <= 0:
-            raise ValueError("Number of threads and processes must be greater than zero.")
+            raise ValueError(
+                "Number of threads and processes must be greater than zero."
+            )
        if self.num_processes == 0:
            # Use threading
@@ -141,12 +143,16 @@
            # Use multiprocessing
            self.queue = multiprocessing.JoinableQueue()
            for _ in range(self.num_processes):
-                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
+                p = multiprocessing.Process(
+                    target=worker_process, args=(self.queue, self.num_threads)
+                )
                p.daemon = True
                p.start()
                self.processes.append(p)
-    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
+    def save_image(
+        self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
+    ):
        if isinstance(image, torch.Tensor):
            # Convert tensor to numpy array to minimize main process time
            image = image.cpu().numpy()


@@ -139,7 +139,9 @@ class LeRobotDatasetMetadata:
    def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
        ep_chunk = self.get_episode_chunk(ep_index)
-        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        fpath = self.video_path.format(
+            episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
+        )
        return Path(fpath)
    def get_episode_chunk(self, ep_index: int) -> int:
@@ -183,7 +185,11 @@ class LeRobotDatasetMetadata:
    @property
    def camera_keys(self) -> list[str]:
        """Keys to access visual modalities (regardless of their storage method)."""
-        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
+        return [
+            key
+            for key, ft in self.features.items()
+            if ft["dtype"] in ["video", "image"]
+        ]
    @property
    def names(self) -> dict[str, list | dict]:
@@ -285,7 +291,9 @@ class LeRobotDatasetMetadata:
        """
        for key in self.video_keys:
            if not self.features[key].get("info", None):
-                video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
+                video_path = self.root / self.get_video_file_path(
+                    ep_index=0, vid_key=key
+                )
                self.info["features"][key]["info"] = get_video_info(video_path)
    def __repr__(self):
@@ -619,7 +627,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
            path = str(self.root / "data")
            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
        else:
-            files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
+            files = [
+                str(self.root / self.meta.get_data_file_path(ep_idx))
+                for ep_idx in self.episodes
+            ]
            hf_dataset = load_dataset("parquet", data_files=files, split="train")
        # TODO(aliberts): hf_dataset.set_format("torch")
@@ -643,12 +654,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
    @property
    def num_frames(self) -> int:
        """Number of frames in selected episodes."""
-        return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
+        return (
+            len(self.hf_dataset)
+            if self.hf_dataset is not None
+            else self.meta.total_frames
+        )
    @property
    def num_episodes(self) -> int:
        """Number of episodes selected."""
-        return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
+        return (
+            len(self.episodes)
+            if self.episodes is not None
+            else self.meta.total_episodes
+        )
    @property
    def features(self) -> dict[str, dict]:
@@ -662,16 +681,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
        else:
            return get_hf_features_from_features(self.features)
-    def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
+    def _get_query_indices(
+        self, idx: int, ep_idx: int
+    ) -> tuple[dict[str, list[int | bool]]]:
        ep_start = self.episode_data_index["from"][ep_idx]
        ep_end = self.episode_data_index["to"][ep_idx]
        query_indices = {
-            key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
+            key: [
+                max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
+                for delta in delta_idx
+            ]
            for key, delta_idx in self.delta_indices.items()
        }
        padding = {  # Pad values outside of current episode range
            f"{key}_is_pad": torch.BoolTensor(
-                [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
+                [
+                    (idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
+                    for delta in delta_idx
+                ]
            )
            for key, delta_idx in self.delta_indices.items()
        }
@@ -771,13 +798,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
            ep_buffer[key] = current_ep_idx if key == "episode_index" else []
        return ep_buffer
-    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
+    def _get_image_file_path(
+        self, episode_index: int, image_key: str, frame_index: int
+    ) -> Path:
        fpath = DEFAULT_IMAGE_PATH.format(
            image_key=image_key, episode_index=episode_index, frame_index=frame_index
        )
        return self.root / fpath
-    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
+    def _save_image(
+        self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
+    ) -> None:
        if self.image_writer is None:
            if isinstance(image, torch.Tensor):
                image = image.cpu().numpy()
@@ -803,7 +834,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # Automatically add frame_index and timestamp to episode buffer
        frame_index = self.episode_buffer["size"]
-        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
+        timestamp = (
+            frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
+        )
        self.episode_buffer["frame_index"].append(frame_index)
        self.episode_buffer["timestamp"].append(timestamp)
@@ -821,7 +854,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
            if self.features[key]["dtype"] in ["image", "video"]:
                img_path = self._get_image_file_path(
-                    episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
+                    episode_index=self.episode_buffer["episode_index"],
+                    image_key=key,
+                    frame_index=frame_index,
                )
                if frame_index == 0:
                    img_path.parent.mkdir(parents=True, exist_ok=True)
@@ -1132,7 +1167,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
    def features(self) -> datasets.Features:
        features = {}
        for dataset in self._datasets:
-            features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
+            features.update(
+                {
+                    k: v
+                    for k, v in dataset.hf_features.items()
+                    if k not in self.disabled_features
+                }
+            )
        return features
    @property
@@ -1193,7 +1234,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                continue
            break
        else:
-            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
+            raise AssertionError(
+                "We expect the loop to break out as long as the index is within bounds."
+            )
        item = self._datasets[dataset_idx][idx - start_idx]
        item["dataset_index"] = torch.tensor(dataset_idx)
        for data_key in self.disabled_features:


@@ -131,7 +131,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
        else:
            self._delta_timestamps = None
-    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
+    def _make_data_spec(
+        self, data_spec: dict[str, Any], buffer_capacity: int
+    ) -> dict[str, dict[str, Any]]:
        """Makes the data spec for np.memmap."""
        if any(k.startswith("_") for k in data_spec):
            raise ValueError(
@@ -154,14 +156,32 @@ class OnlineBuffer(torch.utils.data.Dataset):
            OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
            # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
            # with real data rather than the dummy initialization.
-            OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
-            OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
-            OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
-            OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
-            OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
+            OnlineBuffer.OCCUPANCY_MASK_KEY: {
+                "dtype": np.dtype("?"),
+                "shape": (buffer_capacity,),
+            },
+            OnlineBuffer.INDEX_KEY: {
+                "dtype": np.dtype("int64"),
+                "shape": (buffer_capacity,),
+            },
+            OnlineBuffer.FRAME_INDEX_KEY: {
+                "dtype": np.dtype("int64"),
+                "shape": (buffer_capacity,),
+            },
+            OnlineBuffer.EPISODE_INDEX_KEY: {
+                "dtype": np.dtype("int64"),
+                "shape": (buffer_capacity,),
+            },
+            OnlineBuffer.TIMESTAMP_KEY: {
+                "dtype": np.dtype("float64"),
+                "shape": (buffer_capacity,),
+            },
        }
        for k, v in data_spec.items():
-            complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
+            complete_data_spec[k] = {
+                "dtype": v["dtype"],
+                "shape": (buffer_capacity, *v["shape"]),
+            }
        return complete_data_spec
    def add_data(self, data: dict[str, np.ndarray]):
@@ -188,7 +208,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
        # Shift the incoming indices if necessary.
        if self.num_frames > 0:
-            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
+            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
+                next_index - 1
+            ]
            last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
            data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
            data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@@ -223,7 +245,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
    @property
    def num_episodes(self) -> int:
        return len(
-            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
+            np.unique(
+                self._data[OnlineBuffer.EPISODE_INDEX_KEY][
+                    self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
+                ]
+            )
        )
    @property
@@ -261,7 +287,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
            )
        )[0]
-        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
+        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
+            episode_data_indices
+        ]
        for data_key in self.delta_timestamps:
            # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@@ -278,7 +306,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
            # Check violated query timestamps are all outside the episode range.
            assert (
-                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
+                (query_ts[is_pad] < episode_timestamps[0])
+                | (episode_timestamps[-1] < query_ts[is_pad])
            ).all(), (
                f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
                ") inside the episode range."
@@ -293,7 +322,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
    def get_data_by_key(self, key: str) -> torch.Tensor:
        """Returns all data for a given data key as a Tensor."""
-        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
+        return torch.from_numpy(
+            self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
+        )
def compute_sampler_weights(
@@ -324,13 +355,19 @@ def compute_sampler_weights(
    - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
      included here to avoid adding complexity.
    """
-    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
-        raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
+    if len(offline_dataset) == 0 and (
+        online_dataset is None or len(online_dataset) == 0
+    ):
+        raise ValueError(
+            "At least one of `offline_dataset` or `online_dataset` should be contain data."
+        )
    if (online_dataset is None) ^ (online_sampling_ratio is None):
        raise ValueError(
            "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
        )
-    offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
+    offline_sampling_ratio = (
+        0 if online_sampling_ratio is None else 1 - online_sampling_ratio
+    )
    weights = []


@@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts):
    return data_dict
-def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
+def save_images_concurrently(
+    imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
+):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
@@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
    num_images = len(imgs_array)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
+        [
+            executor.submit(save_image, imgs_array[i], i, out_dir)
+            for i in range(num_images)
+        ]
def get_default_encoding() -> dict:
@@ -64,7 +69,8 @@ def get_default_encoding() -> dict:
    return {
        k: v.default
        for k, v in signature.parameters.items()
-        if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
+        if v.default is not inspect.Parameter.empty
+        and k in ["vcodec", "pix_fmt", "g", "crf"]
    }
@@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None:
# TODO(aliberts): remove
-def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
+def calculate_episode_data_index(
+    hf_dataset: datasets.Dataset,
+) -> Dict[str, torch.Tensor]:
    """
    Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.


@@ -43,7 +43,10 @@ class EpisodeAwareSampler:
        ):
            if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
                indices.extend(
-                    range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
+                    range(
+                        start_index.item() + drop_n_first_frames,
+                        end_index.item() - drop_n_last_frames,
+                    )
                )
        self.indices = indices


@@ -58,7 +58,9 @@ class RandomSubsetApply(Transform):
        elif not isinstance(n_subset, int):
            raise TypeError("n_subset should be an int or None")
        elif not (1 <= n_subset <= len(transforms)):
-            raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
+            raise ValueError(
+                f"n_subset should be in the interval [1, {len(transforms)}]"
+            )
        self.transforms = transforms
        total = sum(p)
@@ -119,16 +121,22 @@ class SharpnessJitter(Transform):
    def _check_input(self, sharpness):
        if isinstance(sharpness, (int, float)):
            if sharpness < 0:
-                raise ValueError("If sharpness is a single number, it must be non negative.")
+                raise ValueError(
+                    "If sharpness is a single number, it must be non negative."
+                )
            sharpness = [1.0 - sharpness, 1.0 + sharpness]
            sharpness[0] = max(sharpness[0], 0.0)
        elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
            sharpness = [float(v) for v in sharpness]
        else:
-            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
+            raise TypeError(
+                f"{sharpness=} should be a single number or a sequence with length 2."
+            )
        if not 0.0 <= sharpness[0] <= sharpness[1]:
-            raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
+            raise ValueError(
+                f"sharpnesss values should be between (0., inf), but got {sharpness}."
+            )
        return float(sharpness[0]), float(sharpness[1])


@@ -52,9 +52,15 @@ STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"
-DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
-DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
-DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_VIDEO_PATH = (
+    "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
+)
+DEFAULT_PARQUET_PATH = (
+    "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
+)
+DEFAULT_IMAGE_PATH = (
+    "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+)
DATASET_CARD_TEMPLATE = """
---
@@ -540,7 +546,10 @@ def check_timestamps_sync(
def check_delta_timestamps(
-    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
+    delta_timestamps: dict[str, list[float]],
+    fps: int,
+    tolerance_s: float,
+    raise_value_error: bool = True,
) -> bool:
    """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
    This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
@@ -548,10 +557,14 @@ def check_delta_timestamps(
    """
    outside_tolerance = {}
    for key, delta_ts in delta_timestamps.items():
-        within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
+        within_tolerance = [
+            abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
+        ]
        if not all(within_tolerance):
            outside_tolerance[key] = [
-                ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
+                ts
+                for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
+                if not is_within
            ]
    if len(outside_tolerance) > 0:
@@ -569,7 +582,9 @@ check_delta_timestamps(
    return True
-def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
+def get_delta_indices(
+    delta_timestamps: dict[str, list[float]], fps: int
+) -> dict[str, list[int]]:
    delta_indices = {}
    for key, delta_ts in delta_timestamps.items():
        delta_indices[key] = [round(d * fps) for d in delta_ts]
@@ -634,7 +649,9 @@ def create_lerobot_dataset_card(
        ],
    )
-    card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
+    card_template = (
+        importlib.resources.files("lerobot.common.datasets") / "card_template.md"
+    ).read_text()
    return DatasetCard.from_template(
        card_data=card_data,


@@ -118,7 +118,10 @@ DATASETS = {
        "single_task": "Place the battery into the slot of the remote controller.",
        **ALOHA_STATIC_INFO,
    },
-    "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
+    "aloha_static_candy": {
+        "single_task": "Pick up the candy and unwrap it.",
+        **ALOHA_STATIC_INFO,
+    },
    "aloha_static_coffee": {
        "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
        **ALOHA_STATIC_INFO,
@@ -167,13 +170,22 @@ DATASETS = {
        "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
        **ALOHA_STATIC_INFO,
    },
-    "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
-    "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
+    "aloha_static_ziploc_slide": {
+        "single_task": "Slide open the ziploc bag.",
+        **ALOHA_STATIC_INFO,
+    },
+    "aloha_sim_insertion_scripted": {
+        "single_task": "Insert the peg into the socket.",
+        **ALOHA_STATIC_INFO,
+    },
    "aloha_sim_insertion_scripted_image": {
        "single_task": "Insert the peg into the socket.",
        **ALOHA_STATIC_INFO,
    },
-    "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
+    "aloha_sim_insertion_human": {
+        "single_task": "Insert the peg into the socket.",
+        **ALOHA_STATIC_INFO,
+    },
    "aloha_sim_insertion_human_image": {
        "single_task": "Insert the peg into the socket.",
        **ALOHA_STATIC_INFO,
@@ -194,10 +206,19 @@ DATASETS = {
        "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
        **ALOHA_STATIC_INFO,
    },
-    "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
-    "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
+    "pusht": {
+        "single_task": "Push the T-shaped block onto the T-shaped target.",
+        **PUSHT_INFO,
+    },
+    "pusht_image": {
+        "single_task": "Push the T-shaped block onto the T-shaped target.",
+        **PUSHT_INFO,
+    },
    "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
-    "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
+    "unitreeh1_rearrange_objects": {
+        "single_task": "Put the object into the bin.",
+        **UNITREEH_INFO,
+    },
    "unitreeh1_two_robot_greeting": {
        "single_task": "Greet the other robot with a high five.",
        **UNITREEH_INFO,
@@ -207,13 +228,31 @@ DATASETS = {
        **UNITREEH_INFO,
    },
    "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
-    "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
-    "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
-    "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
+    "xarm_lift_medium_image": {
+        "single_task": "Pick up the cube and lift it.",
+        **XARM_INFO,
+    },
+    "xarm_lift_medium_replay": {
+        "single_task": "Pick up the cube and lift it.",
+        **XARM_INFO,
+    },
+    "xarm_lift_medium_replay_image": {
+        "single_task": "Pick up the cube and lift it.",
+        **XARM_INFO,
+    },
    "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
-    "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
-    "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
-    "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
+    "xarm_push_medium_image": {
+        "single_task": "Push the cube onto the target.",
+        **XARM_INFO,
+    },
+    "xarm_push_medium_replay": {
+        "single_task": "Push the cube onto the target.",
+        **XARM_INFO,
+    },
+    "xarm_push_medium_replay_image": {
+        "single_task": "Push the cube onto the target.",
+        **XARM_INFO,
+    },
    "umi_cup_in_the_wild": {
        "single_task": "Put the cup on the plate.",
        "license": "apache-2.0",


@ -218,7 +218,9 @@ def get_features_from_hf_dataset(
dtype = ft.feature.dtype dtype = ft.feature.dtype
shape = (ft.length,) shape = (ft.length,)
motor_names = ( motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)] robot_config["names"][key]
if robot_config
else [f"motor_{i}" for i in range(ft.length)]
) )
assert len(motor_names) == shape[0] assert len(motor_names) == shape[0]
names = {"motors": motor_names} names = {"motors": motor_names}
@ -242,11 +244,15 @@ def get_features_from_hf_dataset(
return features return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]: def add_task_index_by_episodes(
dataset: Dataset, tasks_by_episodes: dict
) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas() df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values())) tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)} tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()} episodes_to_task_index = {
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int) df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features features = dataset.features
@ -263,10 +269,19 @@ def add_task_index_from_tasks_col(
# HACK: This is to clean some of the instructions in our version of Open X datasets # HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'" prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)" suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean) df[tasks_col] = (
df[tasks_col]
.str.removeprefix(prefix_to_clean)
.str.removesuffix(suffix_to_clean)
)
# Create task_index col # Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict() tasks_by_episode = (
df.groupby("episode_index")[tasks_col]
.unique()
.apply(lambda x: x.tolist())
.to_dict()
)
tasks = df[tasks_col].unique().tolist() tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)} tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int) df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
@ -291,7 +306,9 @@ def split_parquet_by_episodes(
for ep_chunk in range(total_chunks): for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk
)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end): for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@ -323,7 +340,9 @@ def move_videos(
videos_moved = False videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")] video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0: if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")] video_files = [
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
]
videos_moved = True # Videos have already been moved videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys) assert len(video_files) == total_episodes * len(video_keys)
@ -354,7 +373,9 @@ def move_videos(
target_path = DEFAULT_VIDEO_PATH.format( target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
) )
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx) video_file = V1_VIDEO_FILE.format(
video_key=vid_key, episode_index=ep_idx
)
if len(video_dirs) == 1: if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file video_path = video_dirs[0] / video_file
else: else:
@ -371,7 +392,9 @@ def move_videos(
subprocess.run(["git", "push"], cwd=work_dir, check=True) subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None: def fix_lfs_video_files_tracking(
work_dir: Path, lfs_untracked_videos: list[str]
) -> None:
""" """
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case, HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking. there's no other option than to download the actual files and reupload them with lfs tracking.
@ -379,7 +402,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
for i in range(0, len(lfs_untracked_videos), 100): for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100] files = lfs_untracked_videos[i : i + 100]
try: try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True) subprocess.run(
["git", "rm", "--cached", *files],
cwd=work_dir,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:") print("git rm --cached ERROR:")
print(e.stderr) print(e.stderr)
@ -390,10 +418,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
subprocess.run(["git", "push"], cwd=work_dir, check=True) subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None: def fix_gitattributes(
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes) shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True) subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True) subprocess.run(
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
)
subprocess.run(["git", "push"], cwd=work_dir, check=True) subprocess.run(["git", "push"], cwd=work_dir, check=True)
@ -402,7 +434,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
repo_url = f"https://huggingface.co/datasets/{repo_id}" repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run( subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)], [
"git",
"clone",
"--branch",
branch,
"--single-branch",
"--depth",
"1",
repo_url,
str(work_dir),
],
check=True, check=True,
env=env, env=env,
) )
@ -410,13 +452,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
    lfs_tracked_files = subprocess.run(
-        ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
+        ["git", "lfs", "ls-files", "-n"],
+        cwd=work_dir,
+        capture_output=True,
+        text=True,
+        check=True,
    )
    lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
    return [f for f in video_files if f not in lfs_tracked_files]
-def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
+def get_videos_info(
+    repo_id: str, local_dir: Path, video_keys: list[str], branch: str
+) -> dict:
    # Assumes first episode
    video_files = [
        DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
@ -424,7 +472,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
    ]
    hub_api = HfApi()
    hub_api.snapshot_download(
-        repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
+        repo_id=repo_id,
+        repo_type="dataset",
+        local_dir=local_dir,
+        revision=branch,
+        allow_patterns=video_files,
    )
    videos_info_dict = {}
    for vid_key, vid_path in zip(video_keys, video_files, strict=True):
@ -451,7 +503,11 @@ def convert_dataset(
    hub_api = HfApi()
    hub_api.snapshot_download(
-        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
+        repo_id=repo_id,
+        repo_type="dataset",
+        revision=v1,
+        local_dir=v1x_dir,
+        ignore_patterns="videos*/",
    )
    branch = "main"
    if test_branch:
@ -483,19 +539,31 @@ def convert_dataset(
    if single_task:
        tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
-        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
+        tasks_by_episodes = {
+            ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
+        }
    elif tasks_path:
        tasks_by_episodes = load_json(tasks_path)
-        tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
+        tasks_by_episodes = {
+            int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
+        }
        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
-        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
+        tasks_by_episodes = {
+            ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
+        }
    elif tasks_col:
-        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
+        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
+            dataset, tasks_col
+        )
    else:
        raise ValueError
-    assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
-    tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
+    assert set(tasks) == {
+        task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
+    }
+    tasks = [
+        {"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
+    ]
    write_jsonlines(tasks, v20_dir / TASKS_PATH)
    features["task_index"] = {
        "dtype": "int64",
@ -509,14 +577,25 @@ def convert_dataset(
        dataset = dataset.remove_columns(video_keys)
        clean_gitattr = Path(
            hub_api.hf_hub_download(
-                repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
+                repo_id=GITATTRIBUTES_REF,
+                repo_type="dataset",
+                local_dir=local_dir,
+                filename=".gitattributes",
            )
        ).absolute()
        with tempfile.TemporaryDirectory() as tmp_video_dir:
            move_videos(
-                repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
+                repo_id,
+                video_keys,
+                total_episodes,
+                total_chunks,
+                Path(tmp_video_dir),
+                clean_gitattr,
+                branch,
            )
-        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
+        videos_info = get_videos_info(
+            repo_id, v1x_dir, video_keys=video_keys, branch=branch
+        )
        for key in video_keys:
            features[key]["shape"] = (
                videos_info[key].pop("video.height"),
@ -524,15 +603,22 @@ def convert_dataset(
                videos_info[key].pop("video.channels"),
            )
            features[key]["video_info"] = videos_info[key]
-            assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
+            assert math.isclose(
+                videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
+            )
            if "encoding" in metadata_v1:
-                assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
+                assert (
+                    videos_info[key]["video.pix_fmt"]
+                    == metadata_v1["encoding"]["pix_fmt"]
+                )
    else:
        assert metadata_v1.get("video", 0) == 0
        videos_info = None
    # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
+    episode_lengths = split_parquet_by_episodes(
+        dataset, total_episodes, total_chunks, v20_dir
+    )
    if robot_config is not None:
        robot_type = robot_config.type
@ -543,7 +629,11 @@ def convert_dataset(
    # Episodes
    episodes = [
-        {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
+        {
+            "episode_index": ep_idx,
+            "tasks": tasks_by_episodes[ep_idx],
+            "length": episode_lengths[ep_idx],
+        }
        for ep_idx in episode_indices
    ]
    write_jsonlines(episodes, v20_dir / EPISODES_PATH)
@ -566,16 +656,27 @@ def convert_dataset(
    }
    write_json(metadata_v2_0, v20_dir / INFO_PATH)
    convert_stats_to_json(v1x_dir, v20_dir)
-    card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
+    card = create_lerobot_dataset_card(
+        tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
+    )
    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
-        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
+        hub_api.delete_folder(
+            repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
+        )
    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
-        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
+        hub_api.delete_folder(
+            repo_id=repo_id,
+            path_in_repo="meta_data",
+            repo_type="dataset",
+            revision=branch,
+        )
    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
-        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
+        hub_api.delete_folder(
+            repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
+        )
    hub_api.upload_folder(
        repo_id=repo_id,
@ -344,7 +344,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"json", "json",
str(video_path), str(video_path),
] ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    result = subprocess.run(
+        ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
+    )
if result.returncode != 0: if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}") raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@ -358,7 +360,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"has_audio": True, "has_audio": True,
"audio.channels": audio_stream_info.get("channels", None), "audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None), "audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, "audio.bit_rate": int(audio_stream_info["bit_rate"])
if audio_stream_info.get("bit_rate")
else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"]) "audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate") if audio_stream_info.get("sample_rate")
else None, else None,
@ -380,7 +384,9 @@ def get_video_info(video_path: Path | str) -> dict:
"json", "json",
str(video_path), str(video_path),
] ]
-    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    result = subprocess.run(
+        ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
+    )
if result.returncode != 0: if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}") raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@ -70,7 +70,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
return env return env
-def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
+def make_maniskill_env(
+    cfg: DictConfig, n_envs: int | None = None
+) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment""" """Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
@ -87,7 +89,9 @@ def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector
# state should have the size of 25 # state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs) # env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs) # env = PixelWrapper(cfg, env, n_envs)
-    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
+    env._max_episode_steps = env.max_episode_steps = (
+        50  # gym_utils.find_max_episode_steps_value(env)
+    )
env.unwrapped.metadata["render_fps"] = 20 env.unwrapped.metadata["render_fps"] = 20
return env return env
@ -114,7 +118,11 @@ class PixelWrapper(gym.Wrapper):
def _get_obs(self, obs): def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2) frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame) self._frames.append(frame)
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)} return {
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
self.env.device
)
}
def reset(self, seed): def reset(self, seed):
obs, info = self.env.reset() # (seed=seed) obs, info = self.env.reset() # (seed=seed)
@ -148,7 +156,9 @@ class ConvertToLeRobotEnv(gym.Wrapper):
images = torch.concat(images, axis=-1) images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data # flatten the rest of the data which should just be state data
-        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
+        observation = common.flatten_state_dict(
+            observation, use_torch=True, device=self.base_env.device
+        )
ret = dict() ret = dict()
ret["state"] = observation ret["state"] = observation
ret["pixels"] = images ret["pixels"] = images
@ -84,7 +84,9 @@ class Logger:
pretrained_model_dir_name = "pretrained_model" pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth" training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None): def __init__(
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
):
""" """
Args: Args:
log_dir: The directory to save all logs and training outputs to. log_dir: The directory to save all logs and training outputs to.
@ -104,7 +106,9 @@ class Logger:
enable_wandb = cfg.get("wandb", {}).get("enable", False) enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project run_offline = not enable_wandb or not project
if run_offline: if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) logging.info(
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
)
self._wandb = None self._wandb = None
else: else:
os.environ["WANDB_SILENT"] = "true" os.environ["WANDB_SILENT"] = "true"
@ -130,7 +134,9 @@ class Logger:
# Handle custom step key for rl asynchronous training. # Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb self._wandb = wandb
@classmethod @classmethod
@ -151,7 +157,9 @@ class Logger:
""" """
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None): def save_model(
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
):
"""Save the weights of the Policy model using PyTorchModelHubMixin. """Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory. The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
@ -221,22 +229,30 @@ class Logger:
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}" else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
) )
        self.save_model(
-            checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
+            checkpoint_dir / self.pretrained_model_dir_name,
+            policy,
+            wandb_artifact_name=wandb_artifact_name,
        )
-        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
+        self.save_training_state(
+            checkpoint_dir, train_step, optimizer, scheduler, interaction_step
+        )
        os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
-    def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
+    def load_last_training_state(
+        self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
+    ) -> int:
""" """
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step. random state, and return the global training step.
""" """
-        training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
+        training_state = torch.load(
+            self.last_checkpoint_dir / self.training_state_file_name
+        )
        # For the case where the optimizer is a dictionary of optimizers (e.g., sac)
        if type(training_state["optimizer"]) is dict:
-            assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
-                "Optimizer dictionaries do not have the same keys during resume!"
-            )
+            assert set(training_state["optimizer"].keys()) == set(
+                optimizer.keys()
+            ), "Optimizer dictionaries do not have the same keys during resume!"
            for k, v in training_state["optimizer"].items():
                optimizer[k].load_state_dict(v)
        else:
@ -248,10 +264,18 @@ class Logger:
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided." "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
) )
# Small hack to get the expected keys: use `get_global_random_state`. # Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()}) set_global_random_state(
{k: training_state[k] for k in get_global_random_state()}
)
return training_state["step"] return training_state["step"]
-    def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
+    def log_dict(
+        self,
+        d,
+        step: int | None = None,
+        mode="train",
+        custom_step_key: str | None = None,
+    ):
"""Log a dictionary of metrics to WandB.""" """Log a dictionary of metrics to WandB."""
assert mode in {"train", "eval"} assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log. # TODO(alexander-soare): Add local text log.
@ -280,12 +304,20 @@ class Logger:
                continue
            # Do not log the custom step key itself.
-            if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
+            if (
+                self._wandb_custom_step_key is not None
+                and k in self._wandb_custom_step_key
+            ):
                continue
            if custom_step_key is not None:
                value_custom_step = d[custom_step_key]
-                self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
+                self._wandb.log(
+                    {
+                        f"{mode}/{k}": v,
+                        f"{mode}/{custom_step_key}": value_custom_step,
+                    }
+                )
                continue
            self._wandb.log(data={f"{mode}/{k}": v}, step=step)
@ -74,7 +74,9 @@ class ACTPolicy(PreTrainedPolicy):
self.model = ACT(config) self.model = ACT(config)
if config.temporal_ensemble_coeff is not None: if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) self.temporal_ensembler = ACTTemporalEnsembler(
config.temporal_ensemble_coeff, config.chunk_size
)
self.reset() self.reset()
@ -153,7 +155,8 @@ class ACTPolicy(PreTrainedPolicy):
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = ( l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean() ).mean()
loss_dict = {"l1_loss": l1_loss.item()} loss_dict = {"l1_loss": l1_loss.item()}
@ -163,7 +166,12 @@ class ACTPolicy(PreTrainedPolicy):
# KL-divergence per batch element, then take the mean over the batch. # KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details). # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = ( mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() (
-0.5
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
)
.sum(-1)
.mean()
) )
loss_dict["kld_loss"] = mean_kld.item() loss_dict["kld_loss"] = mean_kld.item()
loss = l1_loss + mean_kld * self.config.kl_weight loss = l1_loss + mean_kld * self.config.kl_weight
@ -217,7 +225,9 @@ class ACTTemporalEnsembler:
``` ```
""" """
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) self.ensemble_weights = torch.exp(
-temporal_ensemble_coeff * torch.arange(chunk_size)
)
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset() self.reset()
@ -233,7 +243,9 @@ class ACTTemporalEnsembler:
time steps, and pop/return the next batch of actions in the sequence. time steps, and pop/return the next batch of actions in the sequence.
""" """
self.ensemble_weights = self.ensemble_weights.to(device=actions.device) self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
device=actions.device
)
if self.ensembled_actions is None: if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode. # time step of the episode.
@ -241,19 +253,34 @@
            # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
            # operations later.
            self.ensembled_actions_count = torch.ones(
-                (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
+                (self.chunk_size, 1),
+                dtype=torch.long,
+                device=self.ensembled_actions.device,
            )
        else:
            # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
            # the online update for those entries.
-            self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
-            self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
-            self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
-            self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
+            self.ensembled_actions *= self.ensemble_weights_cumsum[
+                self.ensembled_actions_count - 1
+            ]
+            self.ensembled_actions += (
+                actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
+            )
+            self.ensembled_actions /= self.ensemble_weights_cumsum[
+                self.ensembled_actions_count
+            ]
+            self.ensembled_actions_count = torch.clamp(
+                self.ensembled_actions_count + 1, max=self.chunk_size
+            )
            # The last action, which has no prior online average, needs to get concatenated onto the end.
-            self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
+            self.ensembled_actions = torch.cat(
+                [self.ensembled_actions, actions[:, -1:]], dim=1
+            )
            self.ensembled_actions_count = torch.cat(
-                [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
+                [
+                    self.ensembled_actions_count,
+                    torch.ones_like(self.ensembled_actions_count[-1:]),
+                ]
            )
        # "Consume" the first action.
        action, self.ensembled_actions, self.ensembled_actions_count = (
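
Read as math rather than tensor indexing, the hunk above maintains an exponentially weighted running average over the overlapping steps of successive action chunks: each entry is un-normalized by the cumulative weight of the predictions already averaged, the newest prediction is folded in with its weight from exp(-coeff * arange(chunk_size)), and the sum is re-normalized. A minimal standalone sketch of that bookkeeping (shapes and values here are illustrative assumptions, not code from this repo):

import torch

coeff, chunk_size, action_dim = 0.01, 4, 2
weights = torch.exp(-coeff * torch.arange(chunk_size))      # per-age ensembling weights
cumsum = torch.cumsum(weights, dim=0)                       # running normalizers

ensembled = torch.zeros(1, chunk_size - 1, action_dim)      # current averages
count = torch.ones(chunk_size - 1, 1, dtype=torch.long)     # predictions seen per step
new_chunk = torch.randn(1, chunk_size, action_dim)          # latest predicted chunk

ensembled = ensembled * cumsum[count - 1]                   # undo previous normalization
ensembled = ensembled + new_chunk[:, :-1] * weights[count]  # fold in the newest prediction
ensembled = ensembled / cumsum[count]                       # re-normalize the average
count = torch.clamp(count + 1, max=chunk_size)
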
@ -319,7 +346,9 @@ class ACT(nn.Module):
config.dim_model, config.dim_model,
) )
# Projection layer from the VAE encoder's output to the latent distribution's parameter space. # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) self.vae_encoder_latent_output_proj = nn.Linear(
config.dim_model, config.latent_dim * 2
)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension. # dimension.
num_input_token_encoder = 1 + config.chunk_size num_input_token_encoder = 1 + config.chunk_size
@ -327,20 +356,28 @@ class ACT(nn.Module):
num_input_token_encoder += 1 num_input_token_encoder += 1
self.register_buffer( self.register_buffer(
"vae_encoder_pos_enc", "vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), create_sinusoidal_pos_embedding(
num_input_token_encoder, config.dim_model
).unsqueeze(0),
) )
# Backbone for image feature extraction. # Backbone for image feature extraction.
if self.config.image_features: if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)( backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
weights=config.pretrained_backbone_weights, weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d, norm_layer=FrozenBatchNorm2d,
) )
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map). # feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}. # Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer4": "feature_map"}
)
# Transformer (acts as VAE decoder when training with the variational objective). # Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config) self.encoder = ACTEncoder(config)
@ -386,7 +423,9 @@ class ACT(nn.Module):
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: def forward(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder). """A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure: `batch` should have the following structure:
@ -424,7 +463,9 @@ class ACT(nn.Module):
if self.config.robot_state_feature: if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
if self.config.robot_state_feature: if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
@ -465,20 +506,24 @@ class ACT(nn.Module):
            # When not using the VAE encoder, we set the latent to be all zeros.
            mu = log_sigma_x2 = None
            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
-            latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
-                batch["observation.state"].device
-            )
+            latent_sample = torch.zeros(
+                [batch_size, self.config.latent_dim], dtype=torch.float32
+            ).to(batch["observation.state"].device)
        # Prepare transformer encoder inputs.
        encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
-        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
+        encoder_in_pos_embed = list(
+            self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
+        )
# Robot state token. # Robot state token.
if self.config.robot_state_feature: if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token. # Environment state token.
if self.config.env_state_feature: if self.config.env_state_feature:
encoder_in_tokens.append( encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"]) self.encoder_env_state_input_proj(
batch["observation.environment_state"]
)
) )
# Camera observation features and positional embeddings. # Camera observation features and positional embeddings.
@ -535,12 +580,21 @@ class ACTEncoder(nn.Module):
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__() super().__init__()
self.is_vae_encoder = is_vae_encoder self.is_vae_encoder = is_vae_encoder
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers num_layers = (
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) config.n_vae_encoder_layers
if self.is_vae_encoder
else config.n_encoder_layers
)
self.layers = nn.ModuleList(
[ACTEncoderLayer(config) for _ in range(num_layers)]
)
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward( def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor: ) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@ -551,7 +605,9 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module): class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig): def __init__(self, config: ACTConfig):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers. # Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@ -566,7 +622,9 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation) self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: def forward(
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
skip = x skip = x
if self.pre_norm: if self.pre_norm:
x = self.norm1(x) x = self.norm1(x)
@ -591,7 +649,9 @@ class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig): def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization.""" """Convenience module for running multiple decoder layers followed by normalization."""
super().__init__() super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) self.layers = nn.ModuleList(
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
)
self.norm = nn.LayerNorm(config.dim_model) self.norm = nn.LayerNorm(config.dim_model)
def forward( def forward(
@ -603,7 +663,10 @@ class ACTDecoder(nn.Module):
) -> Tensor: ) -> Tensor:
for layer in self.layers: for layer in self.layers:
x = layer( x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
) )
if self.norm is not None: if self.norm is not None:
x = self.norm(x) x = self.norm(x)
@ -613,8 +676,12 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module): class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig): def __init__(self, config: ACTConfig):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) self.self_attn = nn.MultiheadAttention(
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) config.dim_model, config.n_heads, dropout=config.dropout
)
self.multihead_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers. # Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@ -655,7 +722,9 @@ class ACTDecoderLayer(nn.Module):
if self.pre_norm: if self.pre_norm:
x = self.norm1(x) x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = self.self_attn(q, k, value=x)[
0
] # select just the output, not the attention weights
x = skip + self.dropout1(x) x = skip + self.dropout1(x)
if self.pre_norm: if self.pre_norm:
skip = x skip = x
@ -692,9 +761,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
""" """
def get_position_angle_vec(position): def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] return [
position / np.power(10000, 2 * (hid_j // 2) / dimension)
for hid_j in range(dimension)
]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float() return torch.from_numpy(sinusoid_table).float()
@ -739,7 +813,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** ( inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension 2
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
/ self.dimension
) )
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
@ -747,9 +823,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
# Note: this stack then flatten operation results in interleaved sine and cosine terms. # Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2). # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) pos_embed_x = torch.stack(
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) (x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) ).flatten(3)
pos_embed_y = torch.stack(
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
0, 3, 1, 2
) # (1, C, H, W)
return pos_embed return pos_embed
@ -132,7 +132,11 @@ class DiffusionPolicy(PreTrainedPolicy):
if len(self._queues["action"]) == 0: if len(self._queues["action"]) == 0:
# stack n latest observations from the queue # stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.diffusion.generate_actions(batch) actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary? # TODO(rcadene): make above methods return output dictionary?
@ -189,7 +193,9 @@ class DiffusionModel(nn.Module):
if self.config.env_state_feature: if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0] global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) self.unet = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)
self.noise_scheduler = _make_noise_scheduler( self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type, config.noise_scheduler_type,
@ -209,7 +215,10 @@ class DiffusionModel(nn.Module):
# ========= inference ============ # ========= inference ============
def conditional_sample( def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
) -> Tensor: ) -> Tensor:
device = get_device_from_parameters(self) device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self) dtype = get_dtype_from_parameters(self)
@ -232,7 +241,9 @@ class DiffusionModel(nn.Module):
global_cond=global_cond, global_cond=global_cond,
) )
# Compute previous image: x_t -> x_t-1 # Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample sample = self.noise_scheduler.step(
model_output, t, sample, generator=generator
).prev_sample
return sample return sample
@ -244,27 +255,39 @@ class DiffusionModel(nn.Module):
if self.config.image_features: if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera: if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first. # Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") images_per_camera = einops.rearrange(
batch["observation.images"], "b s n ... -> n (b s) ..."
)
img_features_list = torch.cat( img_features_list = torch.cat(
[ [
encoder(images) encoder(images)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True) for encoder, images in zip(
self.rgb_encoder, images_per_camera, strict=True
)
] ]
) )
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features). # feature dim (effectively concatenating the camera features).
img_features = einops.rearrange( img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps img_features_list,
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
) )
else: else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder. # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder( img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") einops.rearrange(
batch["observation.images"], "b s n ... -> (b s n) ..."
)
) )
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features). # feature dim (effectively concatenating the camera features).
img_features = einops.rearrange( img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps img_features,
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
) )
global_cond_feats.append(img_features) global_cond_feats.append(img_features)
@ -350,7 +373,9 @@ class DiffusionModel(nn.Module):
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
target = batch["action"] target = batch["action"]
else: else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") raise ValueError(
f"Unsupported prediction type {self.config.prediction_type}"
)
loss = F.mse_loss(pred, target, reduction="none") loss = F.mse_loss(pred, target, reduction="none")
@ -410,7 +435,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy # we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models. # and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device. # register as buffer so it's moved to the correct device.
@ -452,7 +479,9 @@ class DiffusionRgbEncoder(nn.Module):
# Always use center crop for eval # Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random: if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else: else:
self.maybe_random_crop = self.center_crop self.maybe_random_crop = self.center_crop
else: else:
@ -473,7 +502,9 @@ class DiffusionRgbEncoder(nn.Module):
self.backbone = _replace_submodules( self.backbone = _replace_submodules(
root_module=self.backbone, root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d), predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
) )
# Set up pooling and final layers. # Set up pooling and final layers.
@ -515,7 +546,9 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules( def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module: ) -> nn.Module:
""" """
Args: Args:
@ -528,7 +561,11 @@ def _replace_submodules(
if predicate(root_module): if predicate(root_module):
return func(root_module) return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list: for *parents, k in replace_list:
parent_module = root_module parent_module = root_module
if len(parents) > 0: if len(parents) > 0:
@ -543,7 +580,9 @@ def _replace_submodules(
else: else:
setattr(parent_module, k, tgt_module) setattr(parent_module, k, tgt_module)
# verify that all BN are replaced # verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module return root_module
@ -571,7 +610,9 @@ class DiffusionConv1dBlock(nn.Module):
super().__init__() super().__init__()
self.block = nn.Sequential( self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), nn.Conv1d(
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
),
nn.GroupNorm(n_groups, out_channels), nn.GroupNorm(n_groups, out_channels),
nn.Mish(), nn.Mish(),
) )
@ -594,9 +635,13 @@ class DiffusionConditionalUnet1d(nn.Module):
# Encoder for the diffusion timestep. # Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential( self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), nn.Linear(
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
),
nn.Mish(), nn.Mish(),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), nn.Linear(
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
),
) )
# The FiLM conditioning dimension. # The FiLM conditioning dimension.
@ -621,10 +666,16 @@ class DiffusionConditionalUnet1d(nn.Module):
self.down_modules.append( self.down_modules.append(
nn.ModuleList( nn.ModuleList(
[ [
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), dim_in, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Downsample as long as it is not the last block. # Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), nn.Conv1d(dim_out, dim_out, 3, 2, 1)
if not is_last
else nn.Identity(),
] ]
) )
) )
@ -633,10 +684,14 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList( self.mid_modules = nn.ModuleList(
[ [
DiffusionConditionalResidualBlock1d( DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
), ),
DiffusionConditionalResidualBlock1d( DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
), ),
] ]
) )
@ -649,10 +704,16 @@ class DiffusionConditionalUnet1d(nn.Module):
nn.ModuleList( nn.ModuleList(
[ [
# dim_in * 2, because it takes the encoder's skip connection as well # dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), dim_in * 2, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Upsample as long as it is not the last block. # Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
if not is_last
else nn.Identity(),
] ]
) )
) )
@ -726,17 +787,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
        self.use_film_scale_modulation = use_film_scale_modulation
        self.out_channels = out_channels
-        self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
+        self.conv1 = DiffusionConv1dBlock(
+            in_channels, out_channels, kernel_size, n_groups=n_groups
+        )
        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
        cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
-        self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
+        self.conv2 = DiffusionConv1dBlock(
+            out_channels, out_channels, kernel_size, n_groups=n_groups
+        )
        # A final convolution for dimension matching the residual (if needed).
        self.residual_conv = (
-            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
+            nn.Conv1d(in_channels, out_channels, 1)
+            if in_channels != out_channels
+            else nn.Identity()
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
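
For context, the FiLM modulation mentioned in the comment above conditions a convolutional feature map by predicting a per-channel scale and/or bias from a conditioning vector. A small self-contained sketch of the idea (the class name, shapes, and the `1 + scale` choice are illustrative assumptions, not this repo's API):

import torch
from torch import nn

class FiLM1d(nn.Module):
    """Feature-wise linear modulation for (batch, channels, time) feature maps."""

    def __init__(self, cond_dim: int, channels: int):
        super().__init__()
        # One scale and one bias per channel, predicted from the conditioning vector.
        self.proj = nn.Linear(cond_dim, 2 * channels)

    def forward(self, feats: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, bias = self.proj(cond).chunk(2, dim=-1)       # (B, C) each
        return feats * (1 + scale.unsqueeze(-1)) + bias.unsqueeze(-1)

film = FiLM1d(cond_dim=8, channels=16)
out = film(torch.randn(2, 16, 50), torch.randn(2, 8))        # -> (2, 16, 50)
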
@ -7,7 +7,9 @@ from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig from .configuration_classifier import ClassifierConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,7 +17,10 @@ class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata.""" """Wrapper for classifier outputs with additional metadata."""
def __init__( def __init__(
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
): ):
self.logits = logits self.logits = logits
self.probabilities = probabilities self.probabilities = probabilities
@ -43,12 +48,14 @@ class Classifier(
name = "classifier" name = "classifier"
def __init__(self, config: ClassifierConfig): def __init__(self, config: ClassifierConfig):
from transformers import AutoImageProcessor, AutoModel from transformers import AutoModel
super().__init__() super().__init__()
self.config = config self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) encoder = AutoModel.from_pretrained(
self.config.model_name, trust_remote_code=True
)
# Extract vision model if we're given a multimodal model # Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"): if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only") logging.info("Multimodal model detected - using vision encoder only")
@ -74,7 +81,9 @@ class Classifier(
self.feature_dim = self.encoder.fc.in_features self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"): elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension self.feature_dim = self.encoder.config.hidden_sizes[
-1
] # Last channel dimension
else: else:
raise ValueError("Unsupported CNN architecture") raise ValueError("Unsupported CNN architecture")
@ -94,14 +103,19 @@ class Classifier(
if hasattr(self.encoder.config, "hidden_size"): if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size input_dim = self.encoder.config.hidden_size
else: else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found") raise ValueError(
"Unsupported transformer architecture since hidden_size is not found"
)
self.classifier_head = nn.Sequential( self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate), nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim), nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(), nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes), nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
) )
self.classifier_head = self.classifier_head.to(self.config.device) self.classifier_head = self.classifier_head.to(self.config.device)
@ -127,7 +141,10 @@ class Classifier(
return features return features
else: # Transformer models else: # Transformer models
outputs = self.encoder(processed) outputs = self.encoder(processed)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: if (
hasattr(outputs, "pooler_output")
and outputs.pooler_output is not None
):
return outputs.pooler_output return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :] return outputs.last_hidden_state[:, 0, :]
@ -143,7 +160,9 @@ class Classifier(
else: else:
probabilities = torch.softmax(logits, dim=-1) probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) return ClassifierOutput(
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
)
def predict_reward(self, x, threshold=0.6): def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2: if self.config.num_classes == 2:
@@ -59,7 +59,9 @@ class SACPolicy(
                config.input_normalization_params
            )
            self.normalize_inputs = Normalize(
                config.input_shapes,
                config.input_normalization_modes,
                input_normalization_params,
            )
        else:
            self.normalize_inputs = nn.Identity()

@@ -90,7 +92,8 @@ class SACPolicy(
            ensemble=Ensemble(
                [
                    CriticHead(
                        input_dim=encoder_critic.output_dim
                        + config.output_shapes["action"][0],
                        **config.critic_network_kwargs,
                    )
                    for _ in range(config.num_critics)

@@ -104,7 +107,8 @@ class SACPolicy(
            ensemble=Ensemble(
                [
                    CriticHead(
                        input_dim=encoder_critic.output_dim
                        + config.output_shapes["action"][0],
                        **config.critic_network_kwargs,
                    )
                    for _ in range(config.num_critics)

@@ -120,13 +124,17 @@ class SACPolicy(
        self.actor = Policy(
            encoder=encoder_actor,
            network=MLP(
                input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
            ),
            action_dim=config.output_shapes["action"][0],
            encoder_is_shared=config.shared_encoder,
            **config.policy_kwargs,
        )

        if config.target_entropy is None:
            config.target_entropy = (
                -np.prod(config.output_shapes["action"][0]) / 2
            )  # (-dim(A)/2)

        # TODO (azouitine): Handle the case where the temparameter is a fixed
        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -153,7 +161,11 @@ class SACPolicy(
        return actions

    def critic_forward(
        self,
        observations: dict[str, Tensor],
        actions: Tensor,
        use_target: bool = False,
        observation_features: Tensor | None = None,
    ) -> Tensor:
        """Forward pass through a critic network ensemble

@@ -173,21 +185,37 @@ class SACPolicy(
    def update_target_networks(self):
        """Update target networks with exponential moving average"""
        for target_param, param in zip(
            self.critic_target.parameters(),
            self.critic_ensemble.parameters(),
            strict=False,
        ):
            target_param.data.copy_(
                param.data * self.config.critic_target_update_weight
                + target_param.data * (1.0 - self.config.critic_target_update_weight)
            )

    def compute_loss_critic(
        self,
        observations,
        actions,
        rewards,
        next_observations,
        done,
        observation_features: Tensor | None = None,
        next_observation_features: Tensor | None = None,
    ) -> Tensor:
        temperature = self.log_alpha.exp().item()
        with torch.no_grad():
            next_action_preds, next_log_probs, _ = self.actor(
                next_observations, next_observation_features
            )

            # 2- compute q targets
            q_targets = self.critic_forward(
                observations=next_observations,
                actions=next_action_preds,
                use_target=True,
                observation_features=next_observation_features,
            )

            # subsample critics to prevent overfitting if use high UTD (update to date)

@@ -204,7 +232,12 @@ class SACPolicy(
            td_target = rewards + (1 - done) * self.config.discount * min_q

        # 3- compute predicted qs
        q_preds = self.critic_forward(
            observations,
            actions,
            use_target=False,
            observation_features=observation_features,
        )

        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.

@@ -219,20 +252,31 @@ class SACPolicy(
        ).sum()
        return critics_loss

    def compute_loss_temperature(
        self, observations, observation_features: Tensor | None = None
    ) -> Tensor:
        """Compute the temperature loss"""
        # calculate temperature loss
        with torch.no_grad():
            _, log_probs, _ = self.actor(observations, observation_features)
        temperature_loss = (
            -self.log_alpha.exp() * (log_probs + self.config.target_entropy)
        ).mean()
        return temperature_loss

    def compute_loss_actor(
        self, observations, observation_features: Tensor | None = None
    ) -> Tensor:
        temperature = self.log_alpha.exp().item()

        actions_pi, log_probs, _ = self.actor(observations, observation_features)

        q_preds = self.critic_forward(
            observations,
            actions_pi,
            use_target=False,
            observation_features=observation_features,
        )
        min_q_preds = q_preds.min(dim=0)[0]

        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
@@ -259,7 +303,11 @@ class MLP(nn.Module):
        if dropout_rate is not None and dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        layers.append(nn.LayerNorm(hidden_dims[0]))
        layers.append(
            activations
            if isinstance(activations, nn.Module)
            else getattr(nn, activations)()
        )

        # Rest of the layers
        for i in range(1, len(hidden_dims)):

@@ -270,7 +318,9 @@ class MLP(nn.Module):
                layers.append(nn.Dropout(p=dropout_rate))
                layers.append(nn.LayerNorm(hidden_dims[i]))
                layers.append(
                    activations
                    if isinstance(activations, nn.Module)
                    else getattr(nn, activations)()
                )

        self.net = nn.Sequential(*layers)

@@ -381,7 +431,11 @@ class CriticEnsemble(nn.Module):
        actions = self.output_normalization(actions)["action"]
        actions = actions.to(device)

        obs_enc = (
            observation_features
            if observation_features is not None
            else (observations if self.encoder is None else self.encoder(observations))
        )

        inputs = torch.cat([obs_enc, actions], dim=-1)
        q_values = self.ensemble(inputs)  # [num_critics, B, 1]

@@ -445,7 +499,11 @@ class Policy(nn.Module):
        observation_features: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode observations if encoder exists
        obs_enc = (
            observation_features
            if observation_features is not None
            else (observations if self.encoder is None else self.encoder(observations))
        )

        # Get network outputs
        outputs = self.network(obs_enc)

@@ -454,11 +512,15 @@ class Policy(nn.Module):
        # Compute standard deviations
        if self.fixed_std is None:
            log_std = self.std_layer(outputs)
            assert not torch.isnan(
                log_std
            ).any(), "[ERROR] log_std became NaN after std_layer!"
            if self.use_tanh_squash:
                log_std = torch.tanh(log_std)
                log_std = self.log_std_min + 0.5 * (
                    self.log_std_max - self.log_std_min
                ) * (log_std + 1.0)
            else:
                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        else:

@@ -471,7 +533,9 @@ class Policy(nn.Module):
        if self.use_tanh_squash:
            actions = torch.tanh(x_t)
            log_probs -= torch.log(
                (1 - actions.pow(2)) + 1e-6
            )  # Adjust log-probs for Tanh
        else:
            actions = x_t  # No Tanh; raw Gaussian sample
@@ -518,12 +582,15 @@ class SACObservationEncoder(nn.Module):
            freeze_image_encoder(self.image_enc_layers)
        else:
            self.parameters_to_optimize += list(self.image_enc_layers.parameters())
        self.all_image_keys = [
            k for k in config.input_shapes if k.startswith("observation.image")
        ]

        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(
                    in_features=config.input_shapes["observation.state"][0],
                    out_features=config.latent_dim,
                ),
                nn.LayerNorm(normalized_shape=config.latent_dim),
                nn.Tanh(),

@@ -544,7 +611,9 @@ class SACObservationEncoder(nn.Module):
            self.aggregation_size += config.latent_dim
            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())

        self.aggregation_layer = nn.Linear(
            in_features=self.aggregation_size, out_features=config.latent_dim
        )
        self.parameters_to_optimize += list(self.aggregation_layer.parameters())

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:

@@ -557,13 +626,19 @@ class SACObservationEncoder(nn.Module):
        obs_dict = self.input_normalization(obs_dict)
        # Batch all images along the batch dimension, then encode them.
        if len(self.all_image_keys) > 0:
            images_batched = torch.cat(
                [obs_dict[key] for key in self.all_image_keys], dim=0
            )
            images_batched = self.image_enc_layers(images_batched)
            embeddings_chunks = torch.chunk(
                images_batched, dim=0, chunks=len(self.all_image_keys)
            )
            feat.extend(embeddings_chunks)

        if "observation.environment_state" in self.config.input_shapes:
            feat.append(
                self.env_state_enc_layers(obs_dict["observation.environment_state"])
            )
        if "observation.state" in self.config.input_shapes:
            feat.append(self.state_enc_layers(obs_dict["observation.state"]))

@@ -631,7 +706,9 @@ class PretrainedImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.image_enc_layers, self.image_enc_out_shape = (
            self._load_pretrained_vision_encoder(config)
        )
        self.image_enc_proj = nn.Sequential(
            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
            nn.LayerNorm(config.latent_dim),

@@ -642,15 +719,21 @@ class PretrainedImageEncoder(nn.Module):
        """Set up CNN encoder"""
        from transformers import AutoModel

        self.image_enc_layers = AutoModel.from_pretrained(
            config.vision_encoder_name, trust_remote_code=True
        )
        # self.image_enc_layers.pooler = Identity()

        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
                -1
            ]  # Last channel dimension
        elif hasattr(self.image_enc_layers, "fc"):
            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
        else:
            raise ValueError(
                "Unsupported vision encoder architecture, make sure you are using a CNN"
            )
        return self.image_enc_layers, self.image_enc_out_shape

    def forward(self, x):

@@ -673,7 +756,7 @@ def orthogonal_init():
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

@@ -701,7 +784,9 @@ class Ensemble(nn.Module):
        return self.module(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return torch.vmap(self._call, (0, None), randomness="different")(
            self.params, *args, **kwargs
        )

    def __repr__(self):
        return f"Vectorized {len(self)}x " + self._repr

@@ -710,7 +795,9 @@ class Ensemble(nn.Module):
# TODO (azouitine): I think in our case this function is not usefull we should remove it
# after some investigation
# borrowed from tdmpc
def flatten_forward_unflatten(
    fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:

@@ -736,7 +823,9 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
        for key, value in inner_dict.items():
            converted_params[outer_key][key] = torch.tensor(value)
            if "image" in outer_key:
                converted_params[outer_key][key] = converted_params[outer_key][
                    key
                ].view(3, 1, 1)

    return converted_params
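Note: `update_target_networks` earlier in this file is a Polyak (exponential moving average) update, target <- w * online + (1 - w) * target with w = `config.critic_target_update_weight`. A minimal sketch of the same rule, assuming only plain `nn.Module` parameters (the tau value and module names are illustrative):

import torch
import torch.nn as nn

@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    # target <- tau * source + (1 - tau) * target, applied parameter-wise.
    for t_param, s_param in zip(target.parameters(), source.parameters(), strict=True):
        t_param.data.mul_(1.0 - tau).add_(s_param.data, alpha=tau)

# Usage: call after each critic gradient step.
critic = nn.Linear(8, 1)
critic_target = nn.Linear(8, 1)
critic_target.load_state_dict(critic.state_dict())
soft_update(critic_target, critic, tau=0.005)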
@@ -183,7 +183,9 @@ class TDMPCConfig(PreTrainedConfig):
                    "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
                )
            if not self.use_mpc:
                raise ValueError(
                    "If `n_action_steps > 1`, `use_mpc` must be set to `True`."
                )
        if self.n_action_steps > self.horizon:
            raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
@@ -100,7 +100,9 @@ class TDMPCPolicy(PreTrainedPolicy):
        """
        self._queues = {
            "observation.state": deque(maxlen=1),
            "action": deque(
                maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
            ),
        }
        if self.config.image_features:
            self._queues["observation.image"] = deque(maxlen=1)

@@ -189,7 +191,11 @@ class TDMPCPolicy(PreTrainedPolicy):
        # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
        # trajectories.
        z = einops.repeat(
            z,
            "b d -> n b d",
            n=self.config.n_gaussian_samples + self.config.n_pi_samples,
        )

        # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
        # algorithm.

@@ -211,35 +217,47 @@ class TDMPCPolicy(PreTrainedPolicy):
                self.config.action_feature.shape[0],
                device=std.device,
            )
            gaussian_actions = torch.clamp(
                mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
            )

            # Compute elite actions.
            actions = torch.cat([gaussian_actions, pi_actions], dim=1)
            value = self.estimate_value(z, actions).nan_to_num_(0)
            elite_idxs = torch.topk(
                value, self.config.n_elites, dim=0
            ).indices  # (n_elites, batch)
            elite_value = value.take_along_dim(elite_idxs, dim=0)  # (n_elites, batch)
            # (horizon, n_elites, batch, action_dim)
            elite_actions = actions.take_along_dim(
                einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
            )

            # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
            max_value = elite_value.max(0, keepdim=True)[0]  # (1, batch)
            # The weighting is a softmax over trajectory values. Note that this is not the same as the usage
            # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
            # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
            score = torch.exp(
                self.config.elite_weighting_temperature * (elite_value - max_value)
            )
            score /= score.sum(axis=0, keepdim=True)
            # (horizon, batch, action_dim)
            _mean = torch.sum(
                einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
            )
            _std = torch.sqrt(
                torch.sum(
                    einops.rearrange(score, "n b -> n b 1")
                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
                    ** 2,
                    dim=1,
                )
            )
            # Update mean with an exponential moving average, and std with a direct replacement.
            mean = (
                self.config.gaussian_mean_momentum * mean
                + (1 - self.config.gaussian_mean_momentum) * _mean
            )
            std = _std.clamp_(self.config.min_std, self.config.max_std)

@@ -248,7 +266,9 @@ class TDMPCPolicy(PreTrainedPolicy):
        # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
        # scores from the last iteration.
        actions = elite_actions[
            :, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
        ]

        return actions
@@ -271,7 +291,8 @@ class TDMPCPolicy(PreTrainedPolicy):
            # of the FOWM paper.
            if self.config.uncertainty_regularizer_coeff > 0:
                regularization = -(
                    self.config.uncertainty_regularizer_coeff
                    * self.model.Qs(z, actions[t]).std(0)
                )
            else:
                regularization = 0

@@ -291,15 +312,22 @@ class TDMPCPolicy(PreTrainedPolicy):
        if self.config.q_ensemble_size > 2:
            G += (
                running_discount
                * torch.min(
                    terminal_values[
                        torch.randint(0, self.config.q_ensemble_size, size=(2,))
                    ],
                    dim=0,
                )[0]
            )
        else:
            G += running_discount * torch.min(terminal_values, dim=0)[0]
        # Finally, also regularize the terminal value.
        if self.config.uncertainty_regularizer_coeff > 0:
            G -= (
                running_discount
                * self.config.uncertainty_regularizer_coeff
                * terminal_values.std(0)
            )
        return G

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:

@@ -329,7 +357,10 @@ class TDMPCPolicy(PreTrainedPolicy):
        # Apply random image augmentations.
        if self.config.image_features and self.config.max_random_shift_ratio > 0:
            observations["observation.image"] = flatten_forward_unflatten(
                partial(
                    random_shifts_aug,
                    max_random_shift_ratio=self.config.max_random_shift_ratio,
                ),
                observations["observation.image"],
            )

@@ -347,14 +378,20 @@ class TDMPCPolicy(PreTrainedPolicy):
        # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
        # gives us a next `z`.
        batch_size = batch["index"].shape[0]
        z_preds = torch.empty(
            horizon + 1, batch_size, self.config.latent_dim, device=device
        )
        z_preds[0] = self.model.encode(current_observation)
        reward_preds = torch.empty_like(reward, device=device)
        for t in range(horizon):
            z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
                z_preds[t], action[t]
            )

        # Compute Q and V value predictions based on the latent rollout.
        q_preds_ensemble = self.model.Qs(
            z_preds[:-1], action
        )  # (ensemble, horizon, batch)
        v_preds = self.model.V(z_preds[:-1])
        info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})

@@ -368,10 +405,14 @@ class TDMPCPolicy(PreTrainedPolicy):
            # actions (not actions estimated by π).
            # Note: Here we do not use self.model_target, but self.model. This is to follow the original code
            # and the FOWM paper.
            q_targets = reward + self.config.discount * self.model.V(
                self.model.encode(next_observations)
            )
            # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
            # are using them to compute loss for V.
            v_targets = self.model_target.Qs(
                z_preds[:-1].detach(), action, return_min=True
            )

        # Compute losses.
        # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the

@@ -414,7 +455,9 @@ class TDMPCPolicy(PreTrainedPolicy):
            temporal_loss_coeffs
            * F.mse_loss(
                q_preds_ensemble,
                einops.repeat(
                    q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
                ),
                reduction="none",
            ).sum(0)  # sum over ensemble
            # `q_preds_ensemble` depends on the first observation and the actions.

@@ -452,12 +495,14 @@ class TDMPCPolicy(PreTrainedPolicy):
        z_preds = z_preds.detach()
        # Use stopgrad for the advantage calculation.
        with torch.no_grad():
            advantage = self.model_target.Qs(
                z_preds[:-1], action, return_min=True
            ) - self.model.V(z_preds[:-1])
            info["advantage"] = advantage[0]
            # (t, b)
            exp_advantage = torch.clamp(
                torch.exp(advantage * self.config.advantage_scaling), max=100.0
            )
        action_preds = self.model.pi(z_preds[:-1])  # (t, b, a)
        # Calculate the MSE between the actions and the action predictions.
        # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation

@@ -511,7 +556,9 @@ class TDMPCPolicy(PreTrainedPolicy):
        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
        # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
        update_ema_parameters(
            self.model_target, self.model, self.config.target_model_momentum
        )
class TDMPCTOLD(nn.Module):

@@ -598,7 +645,9 @@ class TDMPCTOLD(nn.Module):
                "Sanity check. The last linear layer needs 0 initialization on weights."
            )
            nn.init.zeros_(m[-1].weight)
            nn.init.zeros_(
                m[-1].bias
            )  # this has already been done, but keep this line here for good measure

    def encode(self, obs: dict[str, Tensor]) -> Tensor:
        """Encodes an observation into its latent representation."""

@@ -702,11 +751,26 @@ class TDMPCObservationEncoder(nn.Module):
                    stride=2,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    config.image_encoder_hidden_dim,
                    config.image_encoder_hidden_dim,
                    5,
                    stride=2,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    config.image_encoder_hidden_dim,
                    config.image_encoder_hidden_dim,
                    3,
                    stride=2,
                ),
                nn.ReLU(),
                nn.Conv2d(
                    config.image_encoder_hidden_dim,
                    config.image_encoder_hidden_dim,
                    3,
                    stride=2,
                ),
                nn.ReLU(),
            )
            dummy_shape = (1, *next(iter(config.image_features.values())).shape)

@@ -796,12 +860,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
    """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
    for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
        for (n_p_ema, p_ema), (n_p, p) in zip(
            ema_module.named_parameters(recurse=False),
            module.named_parameters(recurse=False),
            strict=True,
        ):
            assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
            if isinstance(p, dict):
                raise RuntimeError("Dict parameter not supported")
            if (
                isinstance(module, nn.modules.batchnorm._BatchNorm)
                or not p.requires_grad
            ):
                # Copy BatchNorm parameters, and non-trainable parameters directly.
                p_ema.copy_(p.to(dtype=p_ema.dtype).data)
            with torch.no_grad():

@@ -809,7 +878,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
                p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)


def flatten_forward_unflatten(
    fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:
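Note: the CEM/MPPI refit above weights the elite trajectories with a softmax over their values, s = Ω/ΣΩ, then takes μ = Σ(s⋅Γ) and σ = sqrt(Σ(s⋅(Γ−μ)²)). A small sketch of just that re-weighting step, with illustrative shapes and temperature (not the library API):

import torch

def reweight_elites(elite_value, elite_actions, temperature=0.5):
    # elite_value: (n_elites, batch); elite_actions: (horizon, n_elites, batch, action_dim).
    max_value = elite_value.max(0, keepdim=True)[0]
    score = torch.exp(temperature * (elite_value - max_value))
    score = score / score.sum(dim=0, keepdim=True)  # s = Ω/ΣΩ, per batch element
    w = score[None, :, :, None]  # broadcast over horizon and action dims
    mean = (w * elite_actions).sum(dim=1)  # μ: (horizon, batch, action_dim)
    std = torch.sqrt((w * (elite_actions - mean[:, None]) ** 2).sum(dim=1))  # σ
    return mean, std

# Example shapes: horizon=5, n_elites=6, batch=2, action_dim=4.
mean, std = reweight_elites(torch.randn(6, 2), torch.randn(5, 6, 2, 4))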
@@ -145,8 +145,14 @@ class VQBeTPolicy(PreTrainedPolicy):
            )

        if len(self._queues["action"]) == 0:
            batch = {
                k: torch.stack(list(self._queues[k]), dim=1)
                for k in batch
                if k in self._queues
            }
            actions = self.vqbet(batch, rollout=True)[
                :, : self.config.action_chunk_size
            ]

            # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
            actions = self.unnormalize_outputs({"action": actions})["action"]

@@ -168,7 +174,9 @@ class VQBeTPolicy(PreTrainedPolicy):
            # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
            # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
            loss, n_different_codes, n_different_combinations, recon_l1_error = (
                self.vqbet.action_head.discretize(
                    self.config.n_vqvae_training_steps, batch["action"]
                )
            )
            return loss, {
                "n_different_codes": n_different_codes,

@@ -225,7 +233,9 @@ class SpatialSoftmax(nn.Module):
        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
        # and causes a small degradation in pc_success of pre-trained models.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
        # register as buffer so it's moved to the correct device.

@@ -339,7 +349,12 @@ class VQBeTModel(nn.Module):
        num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
        self.register_buffer(
            "select_target_actions_indices",
            torch.row_stack(
                [
                    torch.arange(i, i + self.config.action_chunk_size)
                    for i in range(num_tokens)
                ]
            ),
        )
    def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:

@@ -354,7 +369,11 @@ class VQBeTModel(nn.Module):
        )
        # Separate batch and sequence dims.
        img_features = einops.rearrange(
            img_features,
            "(b s n) ... -> b s n ...",
            b=batch_size,
            s=n_obs_steps,
            n=self.num_images,
        )

        # Arrange prior and current observation step tokens as shown in the class docstring.

@@ -366,13 +385,19 @@ class VQBeTModel(nn.Module):
        input_tokens.append(
            self.state_projector(batch["observation.state"])
        )  # (batch, obs_step, projection dims)
        input_tokens.append(
            einops.repeat(
                self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
            )
        )
        # Interleave tokens by stacking and rearranging.
        input_tokens = torch.stack(input_tokens, dim=2)
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(
            batch_size, len_additional_action_token, 1
        )

        # add additional action query tokens for predicting future action chunks
        input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)

@@ -391,7 +416,11 @@ class VQBeTModel(nn.Module):
        # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
        if len_additional_action_token > 0:
            features = torch.cat(
                [
                    features[:, historical_act_pred_index],
                    features[:, -len_additional_action_token:],
                ],
                dim=1,
            )
        else:
            features = features[:, historical_act_pred_index]

@@ -399,13 +428,15 @@ class VQBeTModel(nn.Module):
        action_head_output = self.action_head(features)
        # if rollout, VQ-BeT don't calculate loss
        if rollout:
            return action_head_output["predicted_action"][
                :, n_obs_steps - 1, :
            ].reshape(batch_size, self.config.action_chunk_size, -1)
        # else, it calculate overall loss (bin prediction loss, and offset loss)
        else:
            output = batch["action"][:, self.select_target_actions_indices]
            loss = self.action_head.loss_fn(
                action_head_output, output, reduction="mean"
            )
            return action_head_output, loss

@@ -440,7 +471,9 @@ class VQBeTHead(nn.Module):
        else:
            self.map_to_cbet_preds_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[
                    self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
                ],
            )
            self.map_to_cbet_preds_offset = MLP(
                in_channels=config.gpt_output_dim,

@@ -467,7 +500,10 @@ class VQBeTHead(nn.Module):
        loss, metric = self.vqvae_model.vqvae_forward(actions)
        n_different_codes = sum(
            [
                len(torch.unique(metric[2][:, i]))
                for i in range(self.vqvae_model.vqvae_num_layers)
            ]
        )
        n_different_combinations = len(torch.unique(metric[2], dim=0))
        recon_l1_error = metric[0].detach().cpu().item()
@@ -514,7 +550,13 @@ class VQBeTHead(nn.Module):
            cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
                torch.cat(
                    (
                        x,
                        F.one_hot(
                            sampled_primary_centers,
                            num_classes=self.config.vqvae_n_embed,
                        ),
                    ),
                    axis=1,
                )
            )

@@ -522,19 +564,29 @@ class VQBeTHead(nn.Module):
                cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            sampled_secondary_centers = einops.rearrange(
                torch.multinomial(
                    cbet_secondary_probs.view(-1, choices), num_samples=1
                ),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = torch.stack(
                (sampled_primary_centers, sampled_secondary_centers), axis=1
            )
            cbet_logits = torch.stack(
                [cbet_primary_logits, cbet_secondary_logits], dim=1
            )
        # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
        else:
            cbet_logits = self.map_to_cbet_preds_bin(x)
            cbet_logits = einops.rearrange(
                cbet_logits,
                "(NT) (G C) -> (NT) G C",
                G=self.vqvae_model.vqvae_num_layers,
            )
            cbet_probs = torch.softmax(
                cbet_logits / self.config.bet_softmax_temperature, dim=-1
            )
            NT, G, choices = cbet_probs.shape
            sampled_centers = einops.rearrange(
                torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),

@@ -554,9 +606,17 @@ class VQBeTHead(nn.Module):
        sampled_offsets = sampled_offsets.sum(dim=1)
        with torch.no_grad():
            # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
            return_decoder_input = (
                self.vqvae_model.get_embeddings_from_code(sampled_centers)
                .clone()
                .detach()
            )
            # pass the centroids through decoder to get actions.
            decoded_action = (
                self.vqvae_model.get_action_from_latent(return_decoder_input)
                .clone()
                .detach()
            )
        # reshaped extracted offset to match with decoded centroids
        sampled_offsets = einops.rearrange(
            sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size

@@ -605,7 +665,9 @@ class VQBeTHead(nn.Module):
        # Figure out the loss for the actions.
        # First, we need to find the closest cluster center for each ground truth action.
        with torch.no_grad():
            state_vq, action_bins = self.vqvae_model.get_code(
                action_seq
            )  # action_bins: NT, G

        # Now we can compute the loss.

@@ -628,8 +690,12 @@ class VQBeTHead(nn.Module):
            + cbet_loss2 * self.config.secondary_code_loss_weight
        )

        equal_primary_code_rate = torch.sum(
            (action_bins[:, 0] == sampled_centers[:, 0]).int()
        ) / (NT)
        equal_secondary_code_rate = torch.sum(
            (action_bins[:, 1] == sampled_centers[:, 1]).int()
        ) / (NT)

        action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
        vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))

@@ -643,7 +709,9 @@ class VQBeTHead(nn.Module):
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
            "equal_secondary_code_rate": equal_secondary_code_rate.detach()
            .cpu()
            .item(),
            "vq_action_error": vq_action_error.detach().cpu().item(),
            "offset_action_error": offset_action_error.detach().cpu().item(),
            "action_error_max": action_error_max.detach().cpu().item(),
@@ -668,7 +736,9 @@ class VQBeTRgbEncoder(nn.Module):
            # Always use center crop for eval
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(
                    config.crop_shape
                )
            else:
                self.maybe_random_crop = self.center_crop
        else:

@@ -689,7 +759,9 @@ class VQBeTRgbEncoder(nn.Module):
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features // 16, num_channels=x.num_features
                ),
            )

        # Set up pooling and final layers.

@@ -730,7 +802,9 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules(
    root_module: nn.Module,
    predicate: Callable[[nn.Module], bool],
    func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
    """
    Args:

@@ -743,7 +817,11 @@ def _replace_submodules(
    if predicate(root_module):
        return func(root_module)

    replace_list = [
        k.split(".")
        for k, m in root_module.named_modules(remove_duplicate=True)
        if predicate(m)
    ]
    for *parents, k in replace_list:
        parent_module = root_module
        if len(parents) > 0:

@@ -758,7 +836,9 @@ def _replace_submodules(
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all BN are replaced
    assert not any(
        predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
    )
    return root_module
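Note: `_replace_submodules` above is used to swap every BatchNorm2d in the backbone for a GroupNorm with `num_features // 16` groups. A compact sketch of that swap done with plain recursion, assuming a torchvision ResNet-18 backbone (the backbone choice and group ratio are illustrative):

import torch.nn as nn
import torchvision

def swap_batchnorm_for_groupnorm(module: nn.Module) -> nn.Module:
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(
                module,
                name,
                nn.GroupNorm(num_groups=child.num_features // 16, num_channels=child.num_features),
            )
        else:
            swap_batchnorm_for_groupnorm(child)  # recurse into nested blocks
    return module

backbone = swap_batchnorm_for_groupnorm(torchvision.models.resnet18(weights=None))
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())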
@@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
        k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

@@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module):
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C)
        )  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))

@@ -189,12 +197,16 @@ class GPT(nn.Module):
                "ln_f": nn.LayerNorm(config.gpt_hidden_dim),
            }
        )
        self.lm_head = nn.Linear(
            config.gpt_hidden_dim, config.gpt_output_dim, bias=False
        )
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
                )

        # report number of parameters
        n_params = sum(p.numel() for p in self.parameters())

@@ -208,11 +220,17 @@ class GPT(nn.Module):
        )

        # positional encodings that are added to the input embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
            0
        )  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(
            input
        )  # token embeddings of shape (b, t, gpt_hidden_dim)
        pos_emb = self.transformer.wpe(
            pos
        )  # position embeddings of shape (1, t, gpt_hidden_dim)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)

@@ -237,7 +255,9 @@ class GPT(nn.Module):
        # but want to use a smaller block size for some smaller, simpler model
        assert gpt_block_size <= self.config.gpt_block_size
        self.config.gpt_block_size = gpt_block_size
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:gpt_block_size]
        )
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]

@@ -270,7 +290,9 @@ class GPT(nn.Module):
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters {} made it into both decay/no_decay sets!".format(
            str(inter_params)
        )
        assert len(param_dict.keys() - union_params) == 0, (
@@ -368,8 +390,12 @@ class ResidualVQ(nn.Module):
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
        )

        self.num_quantizers = num_quantizers

@@ -377,7 +403,10 @@ class ResidualVQ(nn.Module):
        self.layers = nn.ModuleList(
            [
                VectorQuantize(
                    dim=codebook_dim,
                    codebook_dim=codebook_dim,
                    accept_image_fmap=accept_image_fmap,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]

@@ -448,7 +477,9 @@ class ResidualVQ(nn.Module):
        return all_codes

    def forward(
        self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
    ):
        """
        For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
        First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.

@@ -477,13 +508,17 @@ class ResidualVQ(nn.Module):
        )

        ce_losses = []

        should_quantize_dropout = (
            self.training and self.quantize_dropout and not return_loss
        )

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss
        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (

@@ -492,14 +527,23 @@ class ResidualVQ(nn.Module):
                    - 1
                )

            null_indices_shape = (
                (x.shape[0], *x.shape[-2:])
                if self.accept_image_fmap
                else tuple(x.shape[:2])
            )
            null_indices = torch.full(
                null_indices_shape, -1.0, device=device, dtype=torch.long
            )
            null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)

        # go through the layers
        for quantizer_index, layer in enumerate(self.layers):
            if (
                should_quantize_dropout
                and quantizer_index > rand_quantize_dropout_index
            ):
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

@@ -539,7 +583,9 @@ class ResidualVQ(nn.Module):
        # stack all losses and indices
        all_losses, all_indices = map(
            partial(torch.stack, dim=-1), (all_losses, all_indices)
        )

        ret = (quantized_out, all_indices, all_losses)

@@ -599,8 +645,12 @@ class VectorQuantize(nn.Module):
        codebook_input_dim = codebook_dim * heads
        requires_projection = codebook_input_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
        )
        self.eps = eps
        self.commitment_weight = commitment_weight

@@ -614,10 +664,14 @@ class VectorQuantize(nn.Module):
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        assert not (
            ema_update and learnable_codebook
        ), "learnable codebook not compatible with EMA update"
        assert 0 <= sync_update_v <= 1.0
        assert not (
            sync_update_v > 0.0 and not learnable_codebook
        ), "learnable codebook must be turned on"
self.sync_update_v = sync_update_v self.sync_update_v = sync_update_v
@ -629,7 +683,9 @@ class VectorQuantize(nn.Module):
) )
if sync_codebook is None: if sync_codebook is None:
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1 sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
codebook_kwargs = { codebook_kwargs = {
"dim": codebook_dim, "dim": codebook_dim,
@ -794,11 +850,17 @@ class VectorQuantize(nn.Module):
            # quantize again
-            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
+            quantize, embed_ind, distances = self._codebook(
+                x, **codebook_forward_kwargs
+            )

            if self.training:
                # determine code to use for commitment loss
-                maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
+                maybe_detach = (
+                    torch.detach
+                    if not self.learnable_codebook or freeze_codebook
+                    else identity
+                )

                commit_quantize = maybe_detach(quantize)
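For context, the `maybe_detach` choice above controls whether gradients reach the codebook through the commitment term. A minimal sketch of the standard VQ-VAE commitment loss and straight-through estimator this feeds into, assuming a mean-squared-error form (the library's own loss composition appears further down):

import torch
import torch.nn.functional as F

x = torch.randn(8, 64, requires_grad=True)   # encoder output
quantize = torch.randn(8, 64)                # nearest codebook vectors for x

# commitment loss: keep encoder outputs close to their (detached) codes
commit_loss = F.mse_loss(x, quantize.detach())

# straight-through estimator: forward pass uses the codes, gradients flow to x
quantize_st = x + (quantize - x).detach()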
@ -808,7 +870,9 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0: if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) quantize = quantize + self.sync_update_v * (
quantize - quantize.detach()
)
# function for calculating cross entropy loss to distance matrix # function for calculating cross entropy loss to distance matrix
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@ -841,7 +905,9 @@ class VectorQuantize(nn.Module):
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads) embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
if self.accept_image_fmap: if self.accept_image_fmap:
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width) embed_ind = rearrange(
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
)
if only_one: if only_one:
embed_ind = rearrange(embed_ind, "b 1 -> b") embed_ind = rearrange(embed_ind, "b 1 -> b")
@ -895,8 +961,12 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2] num_codes = codebook.shape[-2]
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes: if (
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes] self.orthogonal_reg_max_codes is not None
) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
codebook = codebook[:, rand_ids] codebook = codebook[:, rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook) orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@ -928,7 +998,9 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True # if masking, only return quantized for where mask has True
if mask is not None: if mask is not None:
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input) quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
return quantize, embed_ind, loss return quantize, embed_ind, loss
@ -1038,7 +1110,9 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num): def batched_sample_vectors(samples, num):
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0) return torch.stack(
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
def pad_shape(shape, size, dim=0): def pad_shape(shape, size, dim=0):
@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0) all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0: if rank == 0:
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) samples_per_rank = sample_multinomial(
num, all_num_samples / all_num_samples.sum()
)
else: else:
samples_per_rank = torch.empty_like(all_num_samples) samples_per_rank = torch.empty_like(all_num_samples)
@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module):
self.eps = eps self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = ( self.reset_cluster_size = (
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code reset_cluster_size
if (reset_cluster_size is not None)
else threshold_ema_dead_code
) )
assert callable(gumbel_sample) assert callable(gumbel_sample)
@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module):
"kmeans init is not compatible with multiple codebooks in distributed environment for now" "kmeans init is not compatible with multiple codebooks in distributed environment for now"
) )
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors self.sample_fn = (
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop sample_vectors_distributed
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init])) self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_number) distributed.all_reduce(variance_number)
batch_variance = variance_number / num_vectors batch_variance = variance_number / num_vectors
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay) self.update_with_decay(
"batch_variance", batch_variance, self.affine_param_batch_decay
)
def replace(self, batch_samples, batch_mask): def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate( for ind, (samples, mask) in enumerate(
@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask): if not torch.any(mask):
continue continue
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item()) sampled = self.sample_fn(
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = rearrange(sampled, "1 ... -> ...") sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled self.embed.data[ind][mask] = sampled
@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4 needs_codebook_dim = x.ndim < 4
sample_codebook_temp = ( sample_codebook_temp = (
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
) )
x = x.float() x = x.float()
@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module):
if self.affine_param: if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt() batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean embed = (embed - self.codebook_mean) * (
batch_std / codebook_std
) + self.batch_mean
dist = -cdist(flatten, embed) dist = -cdist(flatten, embed)
@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module):
if self.training and self.ema_update and not freeze_codebook: if self.training and self.ema_update and not freeze_codebook:
if self.affine_param: if self.affine_param:
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean flatten = (flatten - self.batch_mean) * (
codebook_std / batch_std
) + self.codebook_mean
if mask is not None: if mask is not None:
embed_onehot[~mask] = 0.0 embed_onehot[~mask] = 0.0
@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module):
self.expire_codes_(x) self.expire_codes_(x)
if needs_codebook_dim: if needs_codebook_dim:
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)) quantize, embed_ind = tuple(
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
)
dist = unpack_one(dist, ps, "h * d") dist = unpack_one(dist, ps, "h * d")

@ -79,7 +79,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
img.save(str(path), quality=100) img.save(str(path), quality=100)
logging.info(f"Saved image: {path}") logging.info(f"Saved image: {path}")
except Exception as e: except Exception as e:
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") logging.error(
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
)
def save_images_from_cameras( def save_images_from_cameras(
@ -157,7 +159,9 @@ def save_images_from_cameras(
if time.perf_counter() - start_time > record_time_s: if time.perf_counter() - start_time > record_time_s:
break break
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
frame_index += 1 frame_index += 1
finally: finally:
@ -275,7 +279,9 @@ class IntelRealSenseCamera:
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
) )
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} name_to_serial_dict = {
cam["name"]: cam["serial_number"] for cam in camera_infos
}
cam_sn = name_to_serial_dict[name] cam_sn = name_to_serial_dict[name]
return cam_sn return cam_sn
@ -339,7 +345,9 @@ class IntelRealSenseCamera:
        actual_height = color_profile.height()

        # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
-        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
+        if self.fps is not None and not math.isclose(
+            self.fps, actual_fps, rel_tol=1e-3
+        ):
            # Using `OSError` since it's a broad error that encompasses issues related to device communication
            raise OSError(
                f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@ -359,7 +367,9 @@ class IntelRealSenseCamera:
        self.is_connected = True

-    def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
+    def read(
+        self, temporary_color: str | None = None
+    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
        of type `np.uint8`, contrarily to the pytorch format which is float channel first.
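As the docstring notes, frames come back as height x width x channels `np.uint8` arrays. A minimal sketch of the conversion to the pytorch convention mentioned there (the helper name is hypothetical):

import numpy as np
import torch

def to_pytorch_format(frame_hwc: np.ndarray) -> torch.Tensor:
    # (H, W, C) uint8 in [0, 255]  ->  (C, H, W) float32 in [0, 1]
    tensor = torch.from_numpy(frame_hwc).float() / 255.0
    return tensor.permute(2, 0, 1).contiguous()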
@ -386,11 +396,15 @@ class IntelRealSenseCamera:
color_frame = frame.get_color_frame() color_frame = frame.get_color_frame()
if not color_frame: if not color_frame:
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).") raise OSError(
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
)
color_image = np.asanyarray(color_frame.get_data()) color_image = np.asanyarray(color_frame.get_data())
requested_color_mode = self.color_mode if temporary_color is None else temporary_color requested_color_mode = (
self.color_mode if temporary_color is None else temporary_color
)
if requested_color_mode not in ["rgb", "bgr"]: if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError( raise ValueError(
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@ -418,7 +432,9 @@ class IntelRealSenseCamera:
if self.use_depth: if self.use_depth:
depth_frame = frame.get_depth_frame() depth_frame = frame.get_depth_frame()
if not depth_frame: if not depth_frame:
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).") raise OSError(
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
)
depth_map = np.asanyarray(depth_frame.get_data()) depth_map = np.asanyarray(depth_frame.get_data())
@ -460,7 +476,9 @@ class IntelRealSenseCamera:
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
num_tries += 1 num_tries += 1
time.sleep(1 / self.fps) time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): if num_tries > self.fps and (
self.thread.ident is None or not self.thread.is_alive()
):
raise Exception( raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
) )

@ -45,10 +45,14 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60 MAX_OPENCV_INDEX = 60
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: def find_cameras(
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
) -> list[dict]:
cameras = [] cameras = []
if platform.system() == "Linux": if platform.system() == "Linux":
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") print(
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
)
possible_ports = [str(port) for port in Path("/dev").glob("video*")] possible_ports = [str(port) for port in Path("/dev").glob("video*")]
ports = _find_cameras(possible_ports, mock=mock) ports = _find_cameras(possible_ports, mock=mock)
for port in ports: for port in ports:
@ -180,7 +184,9 @@ def save_images_from_cameras(
dt_s = time.perf_counter() - now dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
if time.perf_counter() - start_time > record_time_s: if time.perf_counter() - start_time > record_time_s:
break break
@ -237,7 +243,9 @@ class OpenCVCamera:
if platform.system() == "Linux": if platform.system() == "Linux":
if isinstance(self.camera_index, int): if isinstance(self.camera_index, int):
self.port = Path(f"/dev/video{self.camera_index}") self.port = Path(f"/dev/video{self.camera_index}")
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): elif isinstance(self.camera_index, str) and is_valid_unix_path(
self.camera_index
):
self.port = Path(self.camera_index) self.port = Path(self.camera_index)
# Retrieve the camera index from a potentially symlinked path # Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port) self.camera_index = get_camera_index_from_unix_port(self.port)
@ -283,7 +291,9 @@ class OpenCVCamera:
def connect(self): def connect(self):
if self.is_connected: if self.is_connected:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") raise RobotDeviceAlreadyConnectedError(
f"OpenCVCamera({self.camera_index}) is already connected."
)
if self.mock: if self.mock:
import tests.cameras.mock_cv2 as cv2 import tests.cameras.mock_cv2 as cv2
@ -344,7 +354,9 @@ class OpenCVCamera:
        actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)

        # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
-        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
+        if self.fps is not None and not math.isclose(
+            self.fps, actual_fps, rel_tol=1e-3
+        ):
            # Using `OSError` since it's a broad error that encompasses issues related to device communication
            raise OSError(
                f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
@ -386,7 +398,9 @@ class OpenCVCamera:
if not ret: if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.") raise OSError(f"Can't capture color image from camera {self.camera_index}.")
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode requested_color_mode = (
self.color_mode if temporary_color_mode is None else temporary_color_mode
)
if requested_color_mode not in ["rgb", "bgr"]: if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError( raise ValueError(

@ -39,7 +39,9 @@ from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method from lerobot.common.utils.utils import get_safe_torch_device, has_method
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): def log_control_info(
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
):
log_items = [] log_items = []
if episode_index is not None: if episode_index is not None:
log_items.append(f"ep:{episode_index}") log_items.append(f"ep:{episode_index}")
@ -106,7 +108,9 @@ def predict_action(observation, policy, device, use_amp):
observation = copy(observation) observation = copy(observation)
with ( with (
torch.inference_mode(), torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), torch.autocast(device_type=device.type)
if device.type == "cuda" and use_amp
else nullcontext(),
): ):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation: for name in observation:
@ -162,7 +166,9 @@ def init_keyboard_listener(assign_rewards=False):
print("Right arrow key pressed. Exiting loop...") print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.left: elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...") print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True events["rerecord_episode"] = True
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.esc: elif key == keyboard.Key.esc:
@ -256,7 +262,9 @@ def control_loop(
raise ValueError("You need to provide a task as argument in `single_task`.") raise ValueError("You need to provide a task as argument in `single_task`.")
if dataset is not None and fps is not None and dataset.fps != fps: if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") raise ValueError(
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
)
timestamp = 0 timestamp = 0
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
@ -291,7 +299,9 @@ def control_loop(
if display_cameras and not is_headless(): if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
for key in image_keys: for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1) cv2.waitKey(1)
if fps is not None: if fps is not None:
@ -361,7 +371,11 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility( def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
) -> None: ) -> None:
features_from_robot = get_features_from_robot(robot, use_videos) features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None: if extra_features is not None:
@ -375,11 +389,14 @@ def sanity_check_dataset_robot_compatibility(
mismatches = [] mismatches = []
for field, dataset_value, present_value in fields: for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) diff = DeepDiff(
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
)
if diff: if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches: if mismatches:
raise ValueError( raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) "Dataset metadata compatibility check failed with mismatches:\n"
+ "\n".join(mismatches)
) )

@ -158,7 +158,9 @@ NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10

-def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
+def convert_degrees_to_steps(
+    degrees: float | np.ndarray, models: str | list[str]
+) -> np.ndarray:
    """This function converts the degree range to the step range for indicating motors rotation.
    It assumes a motor achieves a full rotation by going from -180 degree position to +180.
    The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
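A minimal sketch of the mapping the docstring describes, assuming a single motor with a resolution of 4096 steps per full turn (the real function looks the resolution up per motor model):

import numpy as np

def degrees_to_steps(degrees: float | np.ndarray, resolution: int = 4096) -> np.ndarray:
    # half a turn (180 degrees) corresponds to resolution // 2 steps
    return np.round(np.asarray(degrees) / 180 * (resolution // 2)).astype(np.int64)

# degrees_to_steps(90) -> 1024, degrees_to_steps(-180) -> -2048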
@ -384,7 +386,9 @@ class DynamixelMotorsBus:
indices = [] indices = []
for idx in tqdm.tqdm(possible_ids): for idx in tqdm.tqdm(possible_ids):
try: try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError: except ConnectionError:
continue continue
@ -400,7 +404,9 @@ class DynamixelMotorsBus:
def set_bus_baudrate(self, baudrate): def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate() present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate: if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate) self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate: if self.port_handler.getBaudRate() != baudrate:
@ -421,7 +427,9 @@ class DynamixelMotorsBus:
def set_calibration(self, calibration: dict[str, list]): def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct. """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@ -434,7 +442,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration(values, motor_names) values = self.apply_calibration(values, motor_names)
return values return values
-    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
+    def apply_calibration(
+        self, values: np.ndarray | list, motor_names: list[str] | None
+    ):
        """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
        a "zero position" at 0 degree.
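For a single motor in DEGREE calibration mode, the conversion the docstring describes boils down to an offset plus a rescaling. A minimal sketch under that assumption (drive mode and the other calibration modes are omitted; the helper name is hypothetical):

HALF_TURN_DEGREE = 180

def steps_to_degrees(raw_step: int, homing_offset: int, resolution: int = 4096) -> float:
    # shift by the calibration offset, then map half a turn of steps onto 180 degrees
    return (raw_step + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE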
@ -509,7 +519,9 @@ class DynamixelMotorsBus:
return values return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues. """This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration. Some motors might have values outside of expected maximum bounds after calibration.
@ -551,15 +563,23 @@ class DynamixelMotorsBus:
                    values[i] *= -1

                # Convert from initial range to range [-180, 180] degrees
-                calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
-                in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
+                calib_val = (
+                    (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
+                )
+                in_range = (calib_val > LOWER_BOUND_DEGREE) and (
+                    calib_val < UPPER_BOUND_DEGREE
+                )

                # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
                # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
                # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
                # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
-                low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
-                upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
+                low_factor = (
+                    -(resolution // 2) - values[i] - homing_offset
+                ) / resolution
+                upp_factor = (
+                    (resolution // 2) - values[i] - homing_offset
+                ) / resolution

            elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
                start_pos = self.calibration["start_pos"][calib_idx]
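To make the inequality above concrete, a hypothetical worked example (numbers invented for illustration, with resolution = 4096 and homing_offset = 0):

# values[i] = 5000  ->  calib_val = 5000 / 2048 * 180 ≈ 439 degrees (out of range)
# low_factor = (-2048 - 5000 - 0) / 4096 ≈ -1.72
# upp_factor = ( 2048 - 5000 - 0) / 4096 ≈ -0.72
# the only integer between the bounds is factor = -1, i.e. shift by one full turn:
# (5000 + 0 + 4096 * (-1)) / 2048 * 180 ≈ 79 degrees, back inside ]-180, 180[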
@ -567,7 +587,9 @@ class DynamixelMotorsBus:
# Convert from initial range to range [0, 100] in % # Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] % # Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@ -583,19 +605,27 @@ class DynamixelMotorsBus:
factor = math.ceil(low_factor) factor = math.ceil(low_factor)
if factor > upp_factor: if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else: else:
factor = math.ceil(upp_factor) factor = math.ceil(upp_factor)
if factor > low_factor: if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" out_of_range_str = (
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning( logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@ -605,7 +635,9 @@ class DynamixelMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`.""" """Inverse of `apply_calibration`."""
if motor_names is None: if motor_names is None:
motor_names = self.motor_names motor_names = self.motor_names
@ -644,7 +676,9 @@ class DynamixelMotorsBus:
values = np.round(values).astype(np.int32) values = np.round(values).astype(np.int32)
return values return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock: if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
@ -746,7 +780,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names) values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors # log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received # log the utc time at which the data was received
@ -755,7 +791,9 @@ class DynamixelMotorsBus:
return values return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock: if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl import tests.motors.mock_dynamixel_sdk as dxl
else: else:
@ -784,7 +822,12 @@ class DynamixelMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}" f"{self.packet_handler.getTxRxResult(comm)}"
) )
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected: if not self.is_connected:
raise RobotDeviceNotConnectedError( raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@ -845,7 +888,9 @@ class DynamixelMotorsBus:
) )
# log the number of seconds it took to write the data to the motors # log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command? # TODO(rcadene): should we log the time before sending the write command?

@ -137,7 +137,9 @@ NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20 NUM_WRITE_RETRY = 20
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation. """This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180. It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@ -365,7 +367,9 @@ class FeetechMotorsBus:
indices = [] indices = []
for idx in tqdm.tqdm(possible_ids): for idx in tqdm.tqdm(possible_ids):
try: try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError: except ConnectionError:
continue continue
@ -381,7 +385,9 @@ class FeetechMotorsBus:
def set_bus_baudrate(self, baudrate): def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate() present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate: if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate) self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate: if self.port_handler.getBaudRate() != baudrate:
@ -402,7 +408,9 @@ class FeetechMotorsBus:
    def set_calibration(self, calibration: dict[str, list]):
        self.calibration = calibration

-    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
+    def apply_calibration_autocorrect(
+        self, values: np.ndarray | list, motor_names: list[str] | None
+    ):
        """This function applies the calibration, automatically detects out of range errors for motor values and attempts to correct them.
        For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@ -415,7 +423,9 @@ class FeetechMotorsBus:
values = self.apply_calibration(values, motor_names) values = self.apply_calibration(values, motor_names)
return values return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree. a "zero position" at 0 degree.
@ -489,7 +499,9 @@ class FeetechMotorsBus:
return values return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues. """This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration. Some motors might have values outside of expected maximum bounds after calibration.
@ -528,18 +540,26 @@ class FeetechMotorsBus:
values[i] *= -1 values[i] *= -1
# Convert from initial range to range [-180, 180] degrees # Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE calib_val = (
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees # Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
low_factor = ( low_factor = (
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution ) / resolution
upp_factor = ( upp_factor = (
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution ) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@ -548,7 +568,9 @@ class FeetechMotorsBus:
# Convert from initial range to range [0, 100] in % # Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] % # Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@ -564,19 +586,27 @@ class FeetechMotorsBus:
factor = math.ceil(low_factor) factor = math.ceil(low_factor)
if factor > upp_factor: if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else: else:
factor = math.ceil(upp_factor) factor = math.ceil(upp_factor)
if factor > low_factor: if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" out_of_range_str = (
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning( logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@ -586,7 +616,9 @@ class FeetechMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`.""" """Inverse of `apply_calibration`."""
if motor_names is None: if motor_names is None:
motor_names = self.motor_names motor_names = self.motor_names
@ -662,7 +694,9 @@ class FeetechMotorsBus:
return values return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock: if self.mock:
import tests.motors.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
@ -771,7 +805,9 @@ class FeetechMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names) values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors # log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received # log the utc time at which the data was received
@ -780,7 +816,9 @@ class FeetechMotorsBus:
return values return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock: if self.mock:
import tests.motors.mock_scservo_sdk as scs import tests.motors.mock_scservo_sdk as scs
else: else:
@ -809,7 +847,12 @@ class FeetechMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}" f"{self.packet_handler.getTxRxResult(comm)}"
) )
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected: if not self.is_connected:
raise RobotDeviceNotConnectedError( raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@ -870,7 +913,9 @@ class FeetechMotorsBus:
) )
# log the number of seconds it took to write the data to the motors # log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command? # TODO(rcadene): should we log the time before sending the write command?

@ -24,9 +24,7 @@ from lerobot.common.robot_devices.motors.dynamixel import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus

-URL_TEMPLATE = (
-    "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
-)
+URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"

# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@ -37,7 +35,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode): def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])): if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode): def apply_drive_mode(position, drive_mode):
@ -78,12 +78,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
``` ```
""" """
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.") raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position") print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...") input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@ -104,10 +108,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain. # of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position") print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...") input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn. # Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@ -116,11 +125,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# Re-compute homing offset to take into account drive mode # Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models) rotated_nearest_pos = compute_nearest_rounded_position(
rotated_drived_pos, arm.motor_models
)
homing_offset = rotated_target_pos - rotated_nearest_pos homing_offset = rotated_target_pos - rotated_nearest_pos
print("\nMove arm to rest position") print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...") input("Press Enter to continue...")
print() print()

@ -26,9 +26,7 @@ from lerobot.common.robot_devices.motors.feetech import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus

-URL_TEMPLATE = (
-    "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
-)
+URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"

# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@ -39,7 +37,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode): def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])): if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode): def apply_drive_mode(position, drive_mode):
@ -140,7 +140,9 @@ def apply_offset(calib, offset):
return calib return calib
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): def run_arm_auto_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
if robot_type == "so100": if robot_type == "so100":
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type) return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
elif robot_type == "moss": elif robot_type == "moss":
@ -149,18 +151,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm
raise ValueError(robot_type) raise ValueError(robot_type)
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): def run_arm_auto_calibration_so100(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.") raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "so100" and arm_type == "follower"): if not (robot_type == "so100" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.") raise NotImplementedError(
"Auto calibration only supports the follower of so100 arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position") print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...") input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254]) # Lower the acceleration of the motors (in [0,254])
@ -207,11 +218,16 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
print("Calibrate elbow_flex") print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate( calib["elbow_flex"] = move_to_calibrate(
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook arm,
"elbow_flex",
positive_first=False,
in_between_move_hook=in_between_move_hook,
) )
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex") arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
)
time.sleep(1) time.sleep(1)
def in_between_move_hook(): def in_between_move_hook():
@ -239,18 +255,30 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
} }
arm.write("Goal_Position", list(positions.values()), list(positions.keys())) arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift") arm.write(
"Goal_Position",
round(calib["shoulder_lift"]["zero_pos"] - 1600),
"shoulder_lift",
)
time.sleep(2) time.sleep(2)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex") arm.write(
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
)
time.sleep(2) time.sleep(2)
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex") arm.write(
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
)
time.sleep(2) time.sleep(2)
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper") arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
time.sleep(2) time.sleep(2)
print("Calibrate wrist_roll") print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate( calib["wrist_roll"] = move_to_calibrate(
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook arm,
"wrist_roll",
invert_drive_mode=True,
positive_first=False,
while_move_hook=while_move_hook,
) )
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
@ -260,7 +288,9 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
time.sleep(1) time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex") arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift") arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
)
time.sleep(1) time.sleep(1)
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1) time.sleep(1)
@ -289,18 +319,27 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
return calib_dict return calib_dict
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): def run_arm_auto_calibration_moss(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.") raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "moss" and arm_type == "follower"): if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.") raise NotImplementedError(
"Auto calibration only supports the follower of moss arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position") print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...") input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254]) # Lower the acceleration of the motors (in [0,254])
@ -384,8 +423,12 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1) time.sleep(1)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift") arm.write(
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex") "Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
)
time.sleep(2) time.sleep(2)
calib_modes = [] calib_modes = []
@ -412,7 +455,9 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
return calib_dict return calib_dict
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): def run_arm_manual_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""This function ensures that a neural network trained on data collected on a given robot """This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance before calibration, setting a same goal position can work on another robot. For instance before calibration, setting a same goal position
for each motor of two different robots will get two very different positions. But after calibration, for each motor of two different robots will get two very different positions. But after calibration,
@ -435,12 +480,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
``` ```
""" """
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.") raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position") print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...") input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@ -460,10 +509,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinematic chain. # of the previous motor in the kinematic chain.
print("\nMove arm to rotated target position") print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...") input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn. # Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@ -475,7 +529,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
homing_offset = rotated_target_pos - rotated_drived_pos homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position") print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...") input("Press Enter to continue...")
print() print()
View File
@ -31,11 +31,16 @@ from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
def ensure_safe_goal_position( def ensure_safe_goal_position(
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float] goal_pos: torch.Tensor,
present_pos: torch.Tensor,
max_relative_target: float | list[float],
): ):
# Cap relative action target magnitude for safety. # Cap relative action target magnitude for safety.
diff = goal_pos - present_pos diff = goal_pos - present_pos
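As a rough illustration of the capping described by that comment (a sketch under assumed names, not the exact body of `ensure_safe_goal_position`): the relative move can be clamped element-wise before it is written to the follower.

import torch

def clamp_goal(goal_pos, present_pos, max_relative_target):
    # Illustrative: never command a move larger than max_relative_target away from the current position.
    diff = goal_pos - present_pos
    safe_diff = torch.clamp(diff, -max_relative_target, max_relative_target)
    return present_pos + safe_diff

present = torch.tensor([2048.0, 1000.0])
goal = torch.tensor([2600.0, 990.0])
print(clamp_goal(goal, present, max_relative_target=100.0))  # tensor([2148.,  990.])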
@ -277,7 +282,9 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own. # to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms: for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper") self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
# Check both arms can be read # Check both arms can be read
for name in self.follower_arms: for name in self.follower_arms:
@ -309,18 +316,26 @@ class ManipulatorRobot:
print(f"Missing calibration file '{arm_calib_path}'") print(f"Missing calibration file '{arm_calib_path}'")
if self.robot_type in ["koch", "koch_bimanual", "aloha"]: if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration from lerobot.common.robot_devices.robots.dynamixel_calibration import (
run_arm_calibration,
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) calibration = run_arm_calibration(
arm, self.robot_type, name, arm_type
)
elif self.robot_type in ["so100", "moss", "lekiwi"]: elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import ( from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration, run_arm_manual_calibration,
) )
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") print(
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
arm_calib_path.parent.mkdir(parents=True, exist_ok=True) arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f: with open(arm_calib_path, "w") as f:
json.dump(calibration, f) json.dump(calibration, f)
@ -339,13 +354,17 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run set robot preset, the torque must be disabled on all motors.") raise ValueError(
"To run set robot preset, the torque must be disabled on all motors."
)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point See [ # you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] all_motors_except_gripper = [
name for name in arm.motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0: if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors # 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper) arm.write("Operating_Mode", 4, all_motors_except_gripper)
@ -374,7 +393,9 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees, # Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms. # so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper") self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
def set_aloha_robot_preset(self): def set_aloha_robot_preset(self):
def set_shadow_(arm): def set_shadow_(arm):
@ -404,11 +425,15 @@ class ManipulatorRobot:
# you could end up with a servo with a position 0 or 4095 at a crucial point See [ # you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [ all_motors_except_gripper = [
name for name in self.follower_arms[name].motor_names if name != "gripper" name
for name in self.follower_arms[name].motor_names
if name != "gripper"
] ]
if len(all_motors_except_gripper) > 0: if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors # 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper) self.follower_arms[name].write(
"Operating_Mode", 4, all_motors_except_gripper
)
# Use 'position control current based' for follower gripper to be limited by the limit of the current. # Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even tho, # It can grasp an object without forcing too much even tho,
@ -456,7 +481,9 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter() before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position") leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name]) leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t self.logs[f"read_leader_{name}_pos_dt_s"] = (
time.perf_counter() - before_lread_t
)
# Send goal position to the follower # Send goal position to the follower
follower_goal_pos = {} follower_goal_pos = {}
@ -477,14 +504,18 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None: if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position") present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos) present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Used when record_data=True # Used when record_data=True
follower_goal_pos[name] = goal_pos follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.float32) goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos) self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
time.perf_counter() - before_fwrite_t
)
# Early exit when recording data is not requested # Early exit when recording data is not requested
if not record_data: if not record_data:
@ -497,7 +528,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter() before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name]) follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position # Create state by concatenating follower current position
state = [] state = []
@ -519,8 +552,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t "delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries # Populate output dictionaries
obs_dict, action_dict = {}, {} obs_dict, action_dict = {}, {}
@ -544,7 +581,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter() before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name]) follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position # Create state by concatenating follower current position
state = [] state = []
@ -559,8 +598,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t "delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries and format to pytorch # Populate output dictionaries and format to pytorch
obs_dict = {} obs_dict = {}
@ -606,7 +649,9 @@ class ManipulatorRobot:
if self.config.max_relative_target is not None: if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position") present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos) present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Save tensor to concat and return # Save tensor to concat and return
action_sent.append(goal_pos) action_sent.append(goal_pos)
View File
@ -52,7 +52,9 @@ class StretchRobot(StretchAPI):
def connect(self) -> None: def connect(self) -> None:
self.is_connected = self.startup() self.is_connected = self.startup()
if not self.is_connected: if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") print(
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
raise ConnectionError() raise ConnectionError()
for name in self.cameras: for name in self.cameras:
@ -60,7 +62,9 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected: if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.") print(
"Could not connect to the cameras, check that all cameras are plugged-in."
)
raise ConnectionError() raise ConnectionError()
self.run_calibration() self.run_calibration()
@ -105,8 +109,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t "delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries # Populate output dictionaries
obs_dict, action_dict = {}, {} obs_dict, action_dict = {}, {}
@ -150,8 +158,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter() before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read() images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name]) images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t "delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries # Populate output dictionaries
obs_dict = {} obs_dict = {}
View File
@ -48,7 +48,8 @@ class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected.""" """Exception raised when the robot device is not connected."""
def __init__( def __init__(
self, message="This robot device is not connected. Try calling `robot_device.connect()` first." self,
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
): ):
self.message = message self.message = message
super().__init__(self.message) super().__init__(self.message)
View File
@ -17,7 +17,9 @@ import importlib
import logging import logging
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: def is_package_available(
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory. Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages. **Note:** this doesn't work for all packages.
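A typical call, assuming the helper above is importable (the tuple ordering follows its signature):

if is_package_available("torch"):
    print("torch can be imported")

available, version = is_package_available("numpy", return_version=True)
print(available, version)  # e.g. True 1.26.4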
View File
@ -28,7 +28,9 @@ def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources # Filter out DeprecationWarnings raised from pkg_resources
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning "ignore",
"pkg_resources is deprecated as an API",
category=DeprecationWarning,
) )
imageio.mimsave(video_path, stacked_frames, fps=fps) imageio.mimsave(video_path, stacked_frames, fps=fps)
View File
@ -148,7 +148,10 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError: # most likely because path1 is not a subpath of path2 except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts common_parts = Path(osp.commonpath([path1, path2])).parts
return Path( return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) "/".join(
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
) )
@ -159,10 +162,26 @@ def print_cuda_memory_usage():
gc.collect() gc.collect()
# Also clear the cache if you want to fully release the memory # Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache() torch.cuda.empty_cache()
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) print(
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) "Current GPU Memory Allocated: {:.2f} MB".format(
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) torch.cuda.memory_allocated(0) / 1024**2
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) )
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
def capture_timestamp_utc(): def capture_timestamp_utc():
@ -232,7 +251,12 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
class TimerManager: class TimerManager:
def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True): def __init__(
self,
elapsed_time_list: list[float] | None = None,
label="Elapsed time",
log=True,
):
self.label = label self.label = label
self.elapsed_time_list = elapsed_time_list self.elapsed_time_list = elapsed_time_list
self.log = log self.log = log
View File
@ -9,7 +9,7 @@ env:
action_dim: 6 action_dim: 6
fps: ${fps} fps: ${fps}
device: mps device: mps
wrapper: wrapper:
crop_params_dict: crop_params_dict:
observation.images.front: [102, 43, 358, 523] observation.images.front: [102, 43, 358, 523]
@ -28,4 +28,4 @@ env:
reward_classifier: reward_classifier:
pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: lerobot/configs/policy/hilserl_classifier.yaml config_path: lerobot/configs/policy/hilserl_classifier.yaml
View File
@ -66,7 +66,7 @@ policy:
observation.image: [3, 64, 64] observation.image: [3, 64, 64]
output_shapes: output_shapes:
action: [7] action: [7]
camera_number: 1 camera_number: 1
# Normalization / Unnormalization # Normalization / Unnormalization
@ -79,7 +79,7 @@ policy:
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00, # 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00, # -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01, # -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] # 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, # max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, # 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
View File
@ -108,20 +108,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
break break
if motor_index == -1: if motor_index == -1:
raise ValueError("No motors detected. Please ensure you have one motor connected.") raise ValueError(
"No motors detected. Please ensure you have one motor connected."
)
print(f"Motor index found at: {motor_index}") print(f"Motor index found at: {motor_index}")
if brand == "feetech": if brand == "feetech":
# Allows ID and BAUDRATE to be written in memory # Allows ID and BAUDRATE to be written in memory
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
if baudrate != baudrate_des: if baudrate != baudrate_des:
print(f"Setting its baudrate to {baudrate_des}") print(f"Setting its baudrate to {baudrate_des}")
baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des) baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)
# The write can fail, so we allow retries # The write can fail, so we allow retries
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx) motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
)
time.sleep(0.5) time.sleep(0.5)
motor_bus.set_bus_baudrate(baudrate_des) motor_bus.set_bus_baudrate(baudrate_des)
present_baudrate_idx = motor_bus.read_with_motor_ids( present_baudrate_idx = motor_bus.read_with_motor_ids(
@ -136,7 +142,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des) motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2) present_idx = motor_bus.read_with_motor_ids(
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
)
if present_idx != motor_idx_des: if present_idx != motor_idx_des:
raise OSError("Failed to write index.") raise OSError("Failed to write index.")
@ -164,12 +172,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
parser.add_argument( parser.add_argument(
"--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)" "--port",
type=str,
required=True,
help="Motors bus port (e.g. dynamixel,feetech)",
)
parser.add_argument(
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
)
parser.add_argument(
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
)
parser.add_argument(
"--ID",
type=int,
required=True,
help="Desired ID of the current motor (e.g. 1,2,3)",
)
parser.add_argument(
"--baudrate",
type=int,
default=1000000,
help="Desired baudrate for the motor (default: 1000000)",
) )
args = parser.parse_args() args = parser.parse_args()
View File
@ -149,7 +149,11 @@ def init_sim_calibration(robot, cfg):
axis_directions = np.array(cfg.get("axis_directions", [1])) axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi offsets = np.array(cfg.get("offsets", [0])) * np.pi
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets} return {
"start_pos": start_pos,
"axis_directions": axis_directions,
"offsets": offsets,
}
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets): def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
@ -170,7 +174,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
leader_pos = robot.leader_arms.main.read("Present_Position") leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos) action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0)) env.step(np.expand_dims(action, 0))
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: if (
teleop_time_s is not None
and time.perf_counter() - start_teleop_t > teleop_time_s
):
print("Teleoperation processes finished.") print("Teleoperation processes finished.")
break break
@ -202,19 +209,27 @@ def record(
# Load pretrained policy # Load pretrained policy
extra_features = ( extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
if assign_rewards
else None
) )
policy = None policy = None
if pretrained_policy_name_or_path is not None: if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) policy, policy_fps, device, use_amp = init_policy(
pretrained_policy_name_or_path, policy_overrides
)
if fps is None: if fps is None:
fps = policy_fps fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") logging.warning(
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
if policy is None and process_action_from_leader is None: if policy is None and process_action_from_leader is None:
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") raise ValueError(
"Either policy or process_action_fn has to be set to enable control in sim."
)
# initialize listener before sim env # initialize listener before sim env
listener, events = init_keyboard_listener(assign_rewards=assign_rewards) listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
@ -256,7 +271,11 @@ def record(
"shape": env.observation_space[obs_key].shape, "shape": env.observation_space[obs_key].shape,
} }
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} features["action"] = {
"dtype": "float32",
"shape": env.action_space.shape,
"names": None,
}
features = {**features, **extra_features} features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes # Create empty dataset or load existing saved episodes
@ -357,7 +376,9 @@ def record(
if events["stop_recording"] or recorded_episodes >= num_episodes: if events["stop_recording"] or recorded_episodes >= num_episodes:
break break
else: else:
logging.info("Waiting for a few seconds before starting next episode recording...") logging.info(
"Waiting for a few seconds before starting next episode recording..."
)
busy_wait(3) busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True) log_say("Stop recording", play_sounds, blocking=True)
@ -375,7 +396,12 @@ def record(
def replay( def replay(
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True env,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
local_files_only: bool = True,
): ):
env = env() env = env()
@ -422,7 +448,10 @@ if __name__ == "__main__":
parser_record = subparsers.add_parser("record", parents=[base_parser]) parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument( parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
) )
parser_record.add_argument( parser_record.add_argument(
"--root", "--root",
@ -448,7 +477,9 @@ if __name__ == "__main__":
required=True, required=True,
help="A description of the task performed during recording that can be used as a language instruction.", help="A description of the task performed during recording that can be used as a language instruction.",
) )
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") parser_record.add_argument(
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument( parser_record.add_argument(
"--run-compute-stats", "--run-compute-stats",
type=int, type=int,
@ -509,7 +540,10 @@ if __name__ == "__main__":
parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument( parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
) )
parser_replay.add_argument( parser_replay.add_argument(
"--root", "--root",
@ -523,7 +557,9 @@ if __name__ == "__main__":
default="lerobot/test", default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
) )
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.") parser_replay.add_argument(
"--episode", type=int, default=0, help="Index of the episodes to replay."
)
args = parser.parse_args() args = parser.parse_args()
View File
@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A" torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" cuda_version = (
torch._C._cuda_getCompiledVersion()
if HAS_TORCH and torch.version.cuda is not None
else "N/A"
)
# TODO(aliberts): refactor into an actual command `lerobot env` # TODO(aliberts): refactor into an actual command `lerobot env`
@ -77,7 +81,9 @@ def display_sys_info() -> dict:
"Using GPU in script?": "<fill in>", "Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>", # "Using distributed or parallel set-up in script?": "<fill in>",
} }
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") print(
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
)
print(format_dict(info)) print(format_dict(info))
return info return info
View File
@ -170,7 +170,10 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished. # available of none of the envs finished.
if "final_info" in info: if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]] successes = [
info["is_success"] if info is not None else False
for info in info["final_info"]
]
else: else:
successes = [False] * env.num_envs successes = [False] * env.num_envs
@ -184,9 +187,13 @@ def rollout(
step += 1 step += 1
running_success_rate = ( running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
.numpy()
.mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
) )
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update() progbar.update()
# Track the final observation. # Track the final observation.
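To make the any-then-mean reduction above concrete, here is a toy batch of three rollouts and four steps (values hypothetical):

import torch
import einops

all_successes = torch.tensor([
    [False, False, True, False],   # rollout 0 eventually succeeds
    [False, False, False, False],  # rollout 1 never succeeds
    [True, False, False, False],   # rollout 2 succeeds immediately
])

per_rollout = einops.reduce(all_successes, "b n -> b", "any")
print(per_rollout)                 # tensor([ True, False,  True])
print(per_rollout.float().mean())  # tensor(0.6667) -> ~66.7% running success rate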
@ -204,7 +211,9 @@ def rollout(
if return_observations: if return_observations:
stacked_observations = {} stacked_observations = {}
for key in all_observations[0]: for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) stacked_observations[key] = torch.stack(
[obs[key] for obs in all_observations], dim=1
)
ret["observation"] = stacked_observations ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"): if hasattr(policy, "use_original_modules"):
@ -266,7 +275,9 @@ def eval_policy(
return return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv): if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023 ep_frames.append(
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv): elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need. # Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@ -278,7 +289,9 @@ def eval_policy(
episode_data: dict | None = None episode_data: dict | None = None
# we dont want progress bar when we use slurm, since it clutters the logs # we dont want progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) progbar = trange(
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
for batch_ix in progbar: for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step. # step.
@ -289,7 +302,8 @@ def eval_policy(
seeds = None seeds = None
else: else:
seeds = range( seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) start_seed + (batch_ix * env.num_envs),
start_seed + ((batch_ix + 1) * env.num_envs),
) )
rollout_data = rollout( rollout_data = rollout(
env, env,
@ -307,13 +321,22 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() mask = (
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
# Extend metrics. # Extend metrics.
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum") batch_sum_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
sum_rewards.extend(batch_sum_rewards.tolist()) sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max") batch_max_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "max"
)
max_rewards.extend(batch_max_rewards.tolist()) max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") batch_successes = einops.reduce(
(rollout_data["success"] * mask), "b n -> b", "any"
)
all_successes.extend(batch_successes.tolist()) all_successes.extend(batch_successes.tolist())
if seeds: if seeds:
all_seeds.extend(seeds) all_seeds.extend(seeds)
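A toy illustration of the done-mask construction from this hunk, with hypothetical values (two rollouts, five steps):

import torch
import einops

n_steps = 5
done_indices = torch.tensor([1, 3])  # first done step of each rollout (hypothetical)

# The +1 mirrors the comment above about keeping the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
print(mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])

reward = torch.ones(2, n_steps)
print(einops.reduce(reward * mask, "b n -> b", "sum"))  # tensor([3., 5.])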
@ -326,17 +349,27 @@ def eval_policy(
rollout_data, rollout_data,
done_indices, done_indices,
start_episode_index=batch_ix * env.num_envs, start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), start_data_index=(
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
fps=env.unwrapped.metadata["render_fps"], fps=env.unwrapped.metadata["render_fps"],
) )
if episode_data is None: if episode_data is None:
episode_data = this_episode_data episode_data = this_episode_data
else: else:
# Some sanity checks to make sure we are correctly compiling the data. # Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0] assert (
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data. # Concatenate the episode data.
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data} episode_data = {
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
# Maybe render video for visualization. # Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0: if max_episodes_rendered > 0 and len(ep_frames) > 0:
@ -354,7 +387,9 @@ def eval_policy(
target=write_video, target=write_video,
args=( args=(
str(video_path), str(video_path),
stacked_frames[: done_index + 1], # + 1 to capture the last observation stacked_frames[
: done_index + 1
], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"], env.unwrapped.metadata["render_fps"],
), ),
) )
@ -363,7 +398,9 @@ def eval_policy(
n_episodes_rendered += 1 n_episodes_rendered += 1
progbar.set_postfix( progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"} {
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
) )
# Wait till all video rendering threads are done. # Wait till all video rendering threads are done.
@ -409,7 +446,11 @@ def eval_policy(
def _compile_episode_data( def _compile_episode_data(
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float rollout_data: dict,
done_indices: Tensor,
start_episode_index: int,
start_data_index: int,
fps: float,
) -> dict: ) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)` """Convenience function for `eval_policy(return_episode_data=True)`
@ -427,12 +468,16 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = { ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1], "action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), "episode_index": torch.tensor(
[start_episode_index + ep_ix] * (num_frames - 1)
),
"frame_index": torch.arange(0, num_frames - 1, 1), "frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps, "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1], "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1], "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
torch.float32
),
} }
# For the last observation frame, all other keys will just be copy padded. # For the last observation frame, all other keys will just be copy padded.
@ -448,7 +493,9 @@ def _compile_episode_data(
for key in ep_dicts[0]: for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts]) data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) data_dict["index"] = torch.arange(
start_data_index, start_data_index + total_frames, 1
)
return data_dict return data_dict
View File
@ -46,7 +46,11 @@ import torch
from tqdm import trange from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import Robot, make_robot from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
init_hydra_config, init_hydra_config,
@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
return return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path) cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config) model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps") model = model.to("mps")
@ -151,11 +161,17 @@ def rollout(
images = [] images = []
for key in image_keys: for key in image_keys:
if display_cameras: if display_cameras:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1) cv2.waitKey(1)
images.append(observation[key].to("mps")) images.append(observation[key].to("mps"))
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0 reward = (
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
all_rewards.append(reward) all_rewards.append(reward)
# print("REWARD : ", reward) # print("REWARD : ", reward)
@ -219,11 +235,19 @@ def eval_policy(
start_eval = time.perf_counter() start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot") progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file) reward_classifier = get_classifier(
reward_classifier_pretrained_path, reward_classifier_config_file
)
for _ in progbar: for _ in progbar:
rollout_data = rollout( rollout_data = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras robot,
policy,
reward_classifier,
fps,
control_time_s,
use_amp,
display_cameras,
) )
rollouts.append(rollout_data) rollouts.append(rollout_data)
@ -289,7 +313,9 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...") print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.left: elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...") print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True events["rerecord_episode"] = True
events["exit_early"] = True events["exit_early"] = True
elif key == keyboard.Key.space: elif key == keyboard.Key.space:
@ -301,7 +327,10 @@ def init_keyboard_listener():
"Place the leader in similar pose to the follower and press space again." "Place the leader in similar pose to the follower and press space again."
) )
events["pause_policy"] = True events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True) log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
else: else:
events["human_intervention_step"] = True events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.") print("Space key pressed. Human intervention starting.")
@ -351,7 +380,9 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
), ),
) )
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument( parser.add_argument(
"--out-dir", "--out-dir",
help=( help=(
@ -360,7 +391,8 @@ if __name__ == "__main__":
), ),
) )
parser.add_argument( parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening") "--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
) )
parser.add_argument( parser.add_argument(
"--reward-classifier-pretrained-path", "--reward-classifier-pretrained-path",


@@ -45,9 +45,13 @@ def find_port():
         print(f"The port of this MotorsBus is '{port}'")
         print("Reconnect the USB cable.")
     elif len(ports_diff) == 0:
-        raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
+        raise OSError(
+            f"Could not detect the port. No difference was found ({ports_diff})."
+        )
     else:
-        raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
+        raise OSError(
+            f"Could not detect the port. More than one port was found ({ports_diff})."
+        )


 if __name__ == "__main__":


@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import random
 from typing import Any, Callable, Optional, Sequence, TypedDict

 import io
@@ -737,7 +736,6 @@ def concatenate_batch_transitions(

 if __name__ == "__main__":
-    import numpy as np
     from tempfile import TemporaryDirectory

     # ===== Test 1: Create and use a synthetic ReplayBuffer =====
@@ -1139,7 +1137,7 @@ if __name__ == "__main__":
     savings_percent = (std_mem - opt_mem) / std_mem * 100

-    print(f"\nMemory optimization result:")
+    print("\nMemory optimization result:")
     print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
     print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
     print(f"- Memory savings for state tensors: {savings_percent:.1f}%")


@@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
+    parser = argparse.ArgumentParser(
+        description="Crop rectangular ROIs from a LeRobot dataset."
+    )
     parser.add_argument(
         "--repo-id",
         type=str,
@@ -247,7 +249,9 @@ if __name__ == "__main__":
     args = parser.parse_args()

     local_files_only = args.root is not None
-    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
+    dataset = LeRobotDataset(
+        repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
+    )

     images = get_image_from_lerobot_dataset(dataset)
     images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
@@ -256,7 +260,7 @@ if __name__ == "__main__":
     if args.crop_params_path is None:
         rois = select_square_roi_for_images(images)
     else:
-        with open(args.crop_params_path, "r") as f:
+        with open(args.crop_params_path) as f:
             rois = json.load(f)

     # rois = {


@ -31,7 +31,9 @@ def find_joint_bounds(
if display_cameras and not is_headless(): if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key] image_keys = [key for key in observation if "image" in key]
for key in image_keys: for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1) cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t timestamp = time.perf_counter() - start_episode_t
@ -57,7 +59,12 @@ if __name__ == "__main__":
nargs="*", nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)", help="Any key=value arguments to override config values (use dots for.nested=overrides)",
) )
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds") parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
args = parser.parse_args() args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)


@@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
 def initialize_replay_buffer(
-    cfg: DictConfig, logger: Logger, device: str, storage_device:str
+    cfg: DictConfig, logger: Logger, device: str, storage_device: str
 ) -> ReplayBuffer:
     if not cfg.resume:
         return ReplayBuffer(


@ -10,7 +10,9 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]: def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation. """Convert environment observation to LeRobot format observation.
Args: Args:
observation: Dictionary of observation batches from a Gym vector environment. observation: Dictionary of observation batches from a Gym vector environment.
@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1] new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0) new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0) new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,)) self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
def reset( def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None self, *, seed: int | None = None, options: dict[str, Any] | None = None
@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper): class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2))) self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
def action(self, action): def action(self, action):
action, telop = action action, telop = action
@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0] action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2))) self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
def step(self, action): def step(self, action):
if isinstance(action, tuple): if isinstance(action, tuple):
@ -137,7 +145,9 @@ def make_maniskill(
env = ManiSkillObservationWrapper(env, device=cfg.env.device) env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False) env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20 env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env) env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env) env = ManiSkillActionWrapper(env)
@ -149,10 +159,11 @@ def make_maniskill(
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
import hydra import hydra
from omegaconf import OmegaConf
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml") parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
args = parser.parse_args() args = parser.parse_args()
# Initialize config # Initialize config


@ -73,7 +73,9 @@ def make_optimizer_and_scheduler(cfg, policy):
}, },
] ]
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay optimizer_params_dicts,
lr=cfg.training.lr,
weight_decay=cfg.training.weight_decay,
) )
lr_scheduler = None lr_scheduler = None
elif cfg.policy.name == "diffusion": elif cfg.policy.name == "diffusion":
@ -100,14 +102,23 @@ def make_optimizer_and_scheduler(cfg, policy):
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
[ [
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr}, {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr}, {
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr}, "params": policy.critic_ensemble.parameters(),
"lr": policy.config.critic_lr,
},
{
"params": policy.temperature.parameters(),
"lr": policy.config.temperature_lr,
},
] ]
) )
lr_scheduler = None lr_scheduler = None
elif cfg.policy.name == "vqbet": elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler from lerobot.common.policies.vqbet.modeling_vqbet import (
VQBeTOptimizer,
VQBeTScheduler,
)
optimizer = VQBeTOptimizer(policy, cfg) optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg) lr_scheduler = VQBeTScheduler(optimizer, cfg)
@ -214,7 +225,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume: if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")


@ -14,7 +14,6 @@
import logging import logging
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path
from pprint import pformat from pprint import pformat
import hydra import hydra
@ -28,14 +27,16 @@ from termcolor import colored
from torch import optim from torch import optim
from torch.autograd import profiler from torch.autograd import profiler
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
@ -50,7 +51,11 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config) model = Classifier(classifier_config)
if cfg.resume: if cfg.resume:
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict()) model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
return model return model
@ -62,7 +67,9 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float() class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels] sample_weights = class_weights[labels]
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
def support_amp(device: torch.device, cfg: DictConfig) -> bool: def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@@ -71,7 +78,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
     return cfg.training.use_amp and device.type in ("cuda", "cpu")


-def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
+def train_epoch(
+    model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
+):
     # Single epoch training loop with AMP support and progress tracking
     model.train()
     correct = 0
@ -85,7 +94,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device) labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP # Forward pass with optional AMP
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(): with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
outputs = model(images) outputs = model(images)
loss = criterion(outputs.logits, labels) loss = criterion(outputs.logits, labels)
@ -130,7 +143,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with ( with (
torch.no_grad(), torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(), torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
): ):
for batch in tqdm(val_loader, desc="Validation"): for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@ -143,7 +158,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
): ):
outputs = model(images) outputs = model(images)
inference_times.append( inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
) )
else: else:
outputs = model(images) outputs = model(images)
@ -161,16 +178,24 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization # Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log: if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))): for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
if model.config.num_classes == 2: if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3) confidence = round(outputs.probabilities[i].item(), 3)
else: else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()] confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
samples.append( samples.append(
{ {
**{ **{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) f"image_{img_key}": wandb.Image(
for img_idx, img_key in enumerate(cfg.training.image_keys) images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
}, },
"true_label": labels[i].item(), "true_label": labels[i].item(),
"predicted": predictions[i].item(), "predicted": predictions[i].item(),
@ -238,15 +263,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
elif device.type == "mps": elif device.type == "mps":
torch.mps.synchronize() torch.mps.synchronize()
with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"): with (
profiler.profile(record_shapes=True) as prof,
profiler.record_function("model_inference"),
):
_ = model(x) _ = model(x)
inference_times.append( inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
) )
inference_times = np.array(inference_times) inference_times = np.array(inference_times)
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std() avg, median, std = (
inference_times.mean(),
np.median(inference_times),
inference_times.std(),
)
print( print(
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device" f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
) )
@ -264,7 +298,11 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std return avg, median, std
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier") @hydra.main(
version_base="1.2",
config_path="../configs/policy",
config_name="hilserl_classifier",
)
def train(cfg: DictConfig) -> None: def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training # Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg)) logging.info(OmegaConf.to_yaml(cfg))
@ -278,7 +316,9 @@ def train(cfg: DictConfig) -> None:
# Setup dataset and dataloaders # Setup dataset and dataloaders
dataset = LeRobotDataset( dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only cfg.dataset_repo_id,
root=cfg.dataset_root,
local_files_only=cfg.local_files_only,
) )
logging.info(f"Dataset size: {len(dataset)}") logging.info(f"Dataset size: {len(dataset)}")
@ -314,7 +354,9 @@ def train(cfg: DictConfig) -> None:
"You have set resume=True, but there is no model checkpoint in " "You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}" f"{Logger.get_last_checkpoint_dir(out_dir)}"
) )
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info( logging.info(
colored( colored(
"You have set resume=True, indicating that you wish to resume a run", "You have set resume=True, indicating that you wish to resume a run",
@ -327,7 +369,9 @@ def train(cfg: DictConfig) -> None:
# Check for differences between the checkpoint configuration and provided configuration. # Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff. # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg) resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
# Ignore the `resume` and parameters. # Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]: if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"] del diff["values_changed"]["root['resume']"]
@ -346,7 +390,11 @@ def train(cfg: DictConfig) -> None:
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate) optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss() criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
grad_scaler = GradScaler(enabled=cfg.training.use_amp) grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters # Log model parameters
@ -362,7 +410,17 @@ def train(cfg: DictConfig) -> None:
for epoch in range(cfg.training.num_epochs): for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}") logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg) train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Periodic validation # Periodic validation
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0: if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:


@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict
import hydra import hydra
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
@ -30,20 +29,17 @@ from tqdm import tqdm
# TODO: Remove the import of maniskill # TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import ( from lerobot.common.utils.utils import (
format_big_number, format_big_number,
get_safe_torch_device, get_safe_torch_device,
init_hydra_config,
init_logging, init_logging,
set_global_seed, set_global_seed,
) )
from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy): def make_optimizers_and_scheduler(cfg, policy):
@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy):
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
) )
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module # We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr) optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None lr_scheduler = None
optimizers = { optimizers = {
"actor": optimizer_actor, "actor": optimizer_actor,
@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels # Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C) # cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@ -198,8 +198,12 @@ class ReplayBuffer:
""" """
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset. # a replay buffer than from a lerobot dataset.
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys) replay_buffer = cls(
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) capacity=len(lerobot_dataset), device=device, state_keys=state_keys
)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions # Fill the replay buffer with the lerobot dataset transitions
for data in list_transition: for data in list_transition:
replay_buffer.add( replay_buffer.add(
@ -244,7 +248,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default: # If not provided, you can either raise an error or define a default:
if state_keys is None: if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.") raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = [] transitions: list[Transition] = []
num_frames = len(dataset) num_frames = len(dataset)
@ -298,36 +304,40 @@ class ReplayBuffer:
# -- Build batched states -- # -- Build batched states --
batch_state = {} batch_state = {}
for key in self.state_keys: for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to( batch_state[key] = torch.cat(
self.device [t["state"][key] for t in list_of_transitions], dim=0
) ).to(self.device)
if key.startswith("observation.image") and self.use_drq: if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key]) batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions -- # -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device) batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
self.device self.device
) )
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states -- # -- Build batched next states --
batch_next_state = {} batch_next_state = {}
for key in self.state_keys: for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to( batch_next_state[key] = torch.cat(
self.device [t["next_state"][key] for t in list_of_transitions], dim=0
) ).to(self.device)
if key.startswith("observation.image") and self.use_drq: if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key]) batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones -- # -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( batch_dones = torch.tensor(
self.device [t["done"] for t in list_of_transitions], dtype=torch.float32
) ).to(self.device)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( batch_dones = torch.tensor(
self.device [t["done"] for t in list_of_transitions], dtype=torch.float32
) ).to(self.device)
# Return a BatchTransition typed dict # Return a BatchTransition typed dict
return BatchTransition( return BatchTransition(
@ -344,7 +354,13 @@ def concatenate_batch_transitions(
) -> BatchTransition: ) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place""" """NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = { left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0) key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"] for key in left_batch_transitions["state"]
} }
left_batch_transitions["action"] = torch.cat( left_batch_transitions["action"] = torch.cat(
@ -355,7 +371,11 @@ def concatenate_batch_transitions(
) )
left_batch_transitions["next_state"] = { left_batch_transitions["next_state"] = {
key: torch.cat( key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0 [
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
) )
for key in left_batch_transitions["next_state"] for key in left_batch_transitions["next_state"]
} }
@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats # Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None, dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
device=device, device=device,
) )
assert isinstance(policy, nn.Module) assert isinstance(policy, nn.Module)
@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# TODO: Handle resume # TODO: Handle resume
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir) log_output_dir(out_dir)
@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
obs = {key: obs[key].to(device, non_blocking=True) for key in obs} obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
replay_buffer = ReplayBuffer( replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys() capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
) )
batch_size = cfg.training.batch_size batch_size = cfg.training.batch_size
@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if interaction_step >= cfg.training.online_step_before_learning: if interaction_step >= cfg.training.online_step_before_learning:
action = policy.select_action(batch=obs) action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy()) next_obs, reward, done, truncated, info = online_env.step(
action.cpu().numpy()
)
else: else:
action = online_env.action_space.sample() action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action) next_obs, reward, done, truncated, info = online_env.step(action)
# HACK # HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True) action = torch.tensor(action, dtype=torch.float32).to(
device, non_blocking=True
)
# HACK: For maniskill # HACK: For maniskill
# next_obs = preprocess_observation(next_obs) # next_obs = preprocess_observation(next_obs)
@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Because we are using a single environment # Because we are using a single environment
# we can safely assume that the episode is done # we can safely assume that the episode is done
if done[0] or truncated[0]: if done[0] or truncated[0]:
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}") logging.info(
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step) f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logger.log_dict(
{"Sum episode reward": sum_reward_episode}, interaction_step
)
sum_reward_episode = 0 sum_reward_episode = 0
# HACK: This is for maniskill # HACK: This is for maniskill
logging.info( logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n" f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
) )
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step) logger.log_dict(
{"Episode success": info["success"].float().item()}, interaction_step
)
replay_buffer.add( replay_buffer.add(
state=obs, state=obs,
@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
training_infos["loss_actor"] = loss_actor.item() training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations) loss_temperature = policy.compute_loss_temperature(
observations=observations
)
optimizers["temperature"].zero_grad() optimizers["temperature"].zero_grad()
loss_temperature.backward() loss_temperature.backward()
optimizers["temperature"].step() optimizers["temperature"].step()
@ -573,7 +611,9 @@ def train_cli(cfg: dict):
) )
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"): def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
from hydra import compose, initialize from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear() hydra.core.global_hydra.GlobalHydra.instance().clear()


@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3 assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape c, h, w = chw_float32_torch.shape
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" assert (
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() c < h and c < w
), f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
)
return hwc_uint8_numpy return hwc_uint8_numpy


@ -81,7 +81,11 @@ def run_server(
static_folder: Path, static_folder: Path,
template_folder: Path, template_folder: Path,
): ):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) app = Flask(
__name__,
static_folder=static_folder.resolve(),
template_folder=template_folder.resolve(),
)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/") @app.route("/")
@ -138,8 +142,12 @@ def run_server(
) )
) )
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>") @app.route(
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes): "/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
)
def show_episode(
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
):
repo_id = f"{dataset_namespace}/{dataset_name}" repo_id = f"{dataset_namespace}/{dataset_name}"
try: try:
if dataset is None: if dataset is None:
@ -171,15 +179,21 @@ def run_server(
} }
if isinstance(dataset, LeRobotDataset): if isinstance(dataset, LeRobotDataset):
video_paths = [ video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys dataset.meta.get_video_file_path(episode_id, key)
for key in dataset.meta.video_keys
] ]
videos_info = [ videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name} {
"url": url_for("static", filename=video_path),
"filename": video_path.parent.name,
}
for video_path in video_paths for video_path in video_paths
] ]
tasks = dataset.meta.episodes[episode_id]["tasks"] tasks = dataset.meta.episodes[episode_id]["tasks"]
else: else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"] video_keys = [
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
]
videos_info = [ videos_info = [
{ {
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/" "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@ -198,16 +212,24 @@ def run_server(
) )
response.raise_for_status() response.raise_for_status()
# Split into lines and parse each line as JSON # Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()] tasks_jsonl = [
json.loads(line) for line in response.text.splitlines() if line.strip()
]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id] filtered_tasks_jsonl = [
row for row in tasks_jsonl if row["episode_index"] == episode_id
]
tasks = filtered_tasks_jsonl[0]["tasks"] tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks videos_info[0]["language_instruction"] = tasks
if episodes is None: if episodes is None:
episodes = list( episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) range(
dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes
)
) )
return render_template( return render_template(
@ -255,7 +277,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else dataset.features[column_name].shape[0] else dataset.features[column_name].shape[0]
) )
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: if (
"names" in dataset.features[column_name]
and dataset.features[column_name]["names"]
):
column_names = dataset.features[column_name]["names"] column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list): while not isinstance(column_names, list):
column_names = list(column_names.values())[0] column_names = list(column_names.values())[0]
@ -278,8 +303,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else: else:
repo_id = dataset.repo_id repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( url = (
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
) )
df = pd.read_parquet(url) df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns data = df[selected_columns] # Select specific columns
@ -312,7 +341,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
] ]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: def get_episode_language_instruction(
dataset: LeRobotDataset, ep_index: int
) -> list[str]:
# check if the dataset has language instructions # check if the dataset has language instructions
if "language_instruction" not in dataset.features: if "language_instruction" not in dataset.features:
return None return None
@ -323,7 +354,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string # with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
"', shape=(), dtype=string)"
)
def get_dataset_info(repo_id: str) -> IterableNamespace: def get_dataset_info(repo_id: str) -> IterableNamespace:
@ -358,7 +391,9 @@ def visualize_dataset_html(
if force_override: if force_override:
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
else: else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") logging.info(
f"Output directory already exists. Loading from it: '{output_dir}'"
)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)


@ -52,7 +52,13 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 frames at the middle of first episode # save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")


@@ -30,7 +30,9 @@ class config:  # noqa: N801
     def enable_device(self, device_id: str):
         self.device_enabled = device_id

-    def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
+    def enable_stream(
+        self, stream_type: stream, width=None, height=None, color_format=None, fps=None
+    ):
         self.stream_type = stream_type
         # Overwrite default values when possible
         self.width = 848 if width is None else width


@@ -37,7 +37,10 @@ pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]'
 import numpy as np
 import pytest

-from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceNotConnectedError,
+)
 from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera

 # Maximum absolute difference between two consecutive images recorded by a camera.
@ -112,7 +115,11 @@ def test_camera(request, camera_type, mock):
) )
# TODO(rcadene): properly set `rtol` # TODO(rcadene): properly set `rtol`
np.testing.assert_allclose( np.testing.assert_allclose(
color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg color_image,
async_color_image,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
) )
# Test disconnecting # Test disconnecting
@ -131,7 +138,11 @@ def test_camera(request, camera_type, mock):
assert camera.color_mode == "bgr" assert camera.color_mode == "bgr"
bgr_color_image = camera.read() bgr_color_image = camera.read()
np.testing.assert_allclose( np.testing.assert_allclose(
color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg color_image,
bgr_color_image[:, :, [2, 1, 0]],
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
) )
del camera del camera
@ -166,7 +177,11 @@ def test_camera(request, camera_type, mock):
rot_color_image = camera.read() rot_color_image = camera.read()
np.testing.assert_allclose( np.testing.assert_allclose(
rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg rot_color_image,
manual_rot_img,
rtol=1e-5,
atol=MAX_PIXEL_DIFFERENCE,
err_msg=error_msg,
) )
del camera del camera
@ -200,7 +215,9 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
if camera_type == "opencv": if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
elif camera_type == "intelrealsense": elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras from lerobot.common.robot_devices.cameras.intelrealsense import (
save_images_from_cameras,
)
# Small `record_time_s` to speedup unit tests # Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock) save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)


@@ -91,7 +91,12 @@ def patch_builtins_input(monkeypatch):

 def pytest_addoption(parser):
-    parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
+    parser.addoption(
+        "--seed",
+        action="store",
+        default="42",
+        help="Set random seed for reproducibility",
+    )


 @pytest.fixture(autouse=True)


@ -364,10 +364,16 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
for transform in transforms: for transform in transforms:
transform_dir = tmp_path / transform transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created." assert transform_dir.exists(), f"{transform} directory was not created."
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory." assert any(
transform_dir.iterdir()
), f"No transformed images found in {transform} directory."
# Check for specific files within each transform directory # Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"] expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
"min.png",
"max.png",
"mean.png",
]
for file_name in expected_files: for file_name in expected_files:
assert (transform_dir / file_name).exists(), ( assert (transform_dir / file_name).exists(), (
f"{file_name} was not found in {transform} directory." f"{file_name} was not found in {transform} directory."


@ -187,7 +187,9 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
writer.wait_until_done() writer.wait_until_done()
assert fpath.exists() assert fpath.exists()
saved_image = np.array(Image.open(fpath)) saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
np.uint8
)
assert np.array_equal(expected_image, saved_image) assert np.array_equal(expected_image, saved_image)
finally: finally:
writer.stop() writer.stop()
@ -202,7 +204,9 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
writer.wait_until_done() writer.wait_until_done()
assert fpath.exists() assert fpath.exists()
saved_image = np.array(Image.open(fpath)) saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
np.uint8
)
assert np.array_equal(expected_image, saved_image) assert np.array_equal(expected_image, saved_image)
finally: finally:
writer.stop() writer.stop()
@ -292,7 +296,9 @@ def test_wait_until_done(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=0, num_threads=4) writer = AsyncImageWriter(num_processes=0, num_threads=4)
try: try:
num_images = 100 num_images = 100
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)] image_arrays = [
img_array_factory(height=500, width=500) for _ in range(num_images)
]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)] fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True): for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True) fpath.parent.mkdir(parents=True, exist_ok=True)


@ -44,13 +44,23 @@ def make_new_buffer(
return buffer, write_dir return buffer, write_dir
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]: def make_spoof_data_frames(
n_episodes: int, n_frames_per_episode: int
) -> dict[str, np.ndarray]:
new_data = { new_data = {
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape), data_key: np.arange(
n_frames_per_episode * n_episodes * np.prod(data_shape)
).reshape(-1, *data_shape),
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes), OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode), OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes), np.arange(n_episodes), n_frames_per_episode
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes), ),
OnlineBuffer.FRAME_INDEX_KEY: np.tile(
np.arange(n_frames_per_episode), n_episodes
),
OnlineBuffer.TIMESTAMP_KEY: np.tile(
np.arange(n_frames_per_episode) / fps, n_episodes
),
} }
return new_data return new_data
@@ -219,47 +229,72 @@ def test_compute_sampler_weights_trivial(
    online_dataset_size: int,
    online_sampling_ratio: float,
):
    offline_dataset = lerobot_dataset_factory(
        tmp_path, total_episodes=1, total_frames=offline_dataset_size
    )
    online_dataset, _ = make_new_buffer()
    if online_dataset_size > 0:
        online_dataset.add_data(
            make_spoof_data_frames(
                n_episodes=2, n_frames_per_episode=online_dataset_size // 2
            )
        )

    weights = compute_sampler_weights(
        offline_dataset,
        online_dataset=online_dataset,
        online_sampling_ratio=online_sampling_ratio,
    )

    if offline_dataset_size == 0 or online_dataset_size == 0:
        expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
    elif online_sampling_ratio == 0:
        expected_weights = torch.cat(
            [torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
        )
    elif online_sampling_ratio == 1:
        expected_weights = torch.cat(
            [torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
        )
    expected_weights /= expected_weights.sum()
    torch.testing.assert_close(weights, expected_weights)


def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(
        tmp_path, total_episodes=1, total_frames=4
    )
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(
        make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
    )
    online_sampling_ratio = 0.8
    weights = compute_sampler_weights(
        offline_dataset,
        online_dataset=online_dataset,
        online_sampling_ratio=online_sampling_ratio,
    )
    torch.testing.assert_close(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
    )


def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
    lerobot_dataset_factory, tmp_path
):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(
        tmp_path, total_episodes=1, total_frames=4
    )
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(
        make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
    )
    weights = compute_sampler_weights(
        offline_dataset,
        online_dataset=online_dataset,
        online_sampling_ratio=0.8,
        online_drop_n_last_frames=1,
    )
    torch.testing.assert_close(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
@@ -268,9 +303,13 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
    """Note: test copied from test_sampler."""
    offline_dataset = lerobot_dataset_factory(
        tmp_path, total_episodes=1, total_frames=2
    )
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(
        make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
    )

    weights = compute_sampler_weights(
        offline_dataset,


@@ -15,7 +15,9 @@
# limitations under the License.
from datasets import Dataset

from lerobot.common.datasets.push_dataset_to_hub.utils import (
    calculate_episode_data_index,
)
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,


@@ -20,17 +20,39 @@ DUMMY_MOTOR_FEATURES = {
    "action": {
        "dtype": "float32",
        "shape": (6,),
        "names": [
            "shoulder_pan",
            "shoulder_lift",
            "elbow_flex",
            "wrist_flex",
            "wrist_roll",
            "gripper",
        ],
    },
    "state": {
        "dtype": "float32",
        "shape": (6,),
        "names": [
            "shoulder_pan",
            "shoulder_lift",
            "elbow_flex",
            "wrist_flex",
            "wrist_roll",
            "gripper",
        ],
    },
}
DUMMY_CAMERA_FEATURES = {
    "laptop": {
        "shape": (480, 640, 3),
        "names": ["height", "width", "channels"],
        "info": None,
    },
    "phone": {
        "shape": (480, 640, 3),
        "names": ["height", "width", "channels"],
        "info": None,
    },
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {


@@ -23,7 +23,11 @@ import PIL.Image
import pytest
import torch

from lerobot.common.datasets.lerobot_dataset import (
    CODEBASE_VERSION,
    LeRobotDataset,
    LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_FEATURES,
@@ -54,7 +58,9 @@ def get_task_index(task_dicts: dict, task: str) -> int:
@pytest.fixture(scope="session")
def img_tensor_factory():
    def _create_img_tensor(
        height=100, width=100, channels=3, dtype=torch.float32
    ) -> torch.Tensor:
        return torch.rand((channels, height, width), dtype=dtype)

    return _create_img_tensor
@@ -62,10 +68,14 @@ def img_tensor_factory():
@pytest.fixture(scope="session")
def img_array_factory():
    def _create_img_array(
        height=100, width=100, channels=3, dtype=np.uint8
    ) -> np.ndarray:
        if np.issubdtype(dtype, np.unsignedinteger):
            # Int array in [0, 255] range
            img_array = np.random.randint(
                0, 256, size=(height, width, channels), dtype=dtype
            )
        elif np.issubdtype(dtype, np.floating):
            # Float array in [0, 1] range
            img_array = np.random.rand(height, width, channels).astype(dtype)
@@ -94,10 +104,13 @@ def features_factory():
    ) -> dict:
        if use_videos:
            camera_ft = {
                key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
                for key, ft in camera_features.items()
            }
        else:
            camera_ft = {
                key: {"dtype": "image", **ft} for key, ft in camera_features.items()
            }
        return {
            **motor_features,
            **camera_ft,
@@ -215,7 +228,9 @@ def episodes_factory(tasks_factory):
        if total_episodes <= 0 or total_frames <= 0:
            raise ValueError("num_episodes and total_length must be positive integers.")
        if total_frames < total_episodes:
            raise ValueError(
                "total_length must be greater than or equal to num_episodes."
            )

        if not tasks:
            min_tasks = 2 if multi_task else 1
@@ -223,10 +238,14 @@ def episodes_factory(tasks_factory):
            tasks = tasks_factory(total_tasks)

        if total_episodes < len(tasks) and not multi_task:
            raise ValueError(
                "The number of tasks should be less than the number of episodes."
            )

        # Generate random lengths that sum up to total_length
        lengths = np.random.multinomial(
            total_frames, [1 / total_episodes] * total_episodes
        ).tolist()

        tasks_list = [task_dict["task"] for task_dict in tasks.values()]
        num_tasks_available = len(tasks_list)
@@ -234,9 +253,13 @@ def episodes_factory(tasks_factory):
        episodes = {}
        remaining_tasks = tasks_list.copy()
        for ep_idx in range(total_episodes):
            num_tasks_in_episode = (
                random.randint(1, min(3, num_tasks_available)) if multi_task else 1
            )
            tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
            episode_tasks = random.sample(
                tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))
            )
            if remaining_tasks:
                for task in episode_tasks:
                    remaining_tasks.remove(task)
@@ -253,7 +276,9 @@ def episodes_factory(tasks_factory):
@pytest.fixture(scope="session")
def hf_dataset_factory(
    features_factory, tasks_factory, episodes_factory, img_array_factory
):
    def _create_hf_dataset(
        features: dict | None = None,
        tasks: list[dict] | None = None,
@@ -275,10 +300,15 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
            timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
            frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
            episode_index_col = np.concatenate(
                (
                    episode_index_col,
                    np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
                )
            )
            ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
            task_index = np.concatenate(
                (task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))
            )

        index_col = np.arange(len(episode_index_col))
@@ -290,7 +320,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
                    for _ in range(len(index_col))
                ]
            elif ft["shape"][0] > 1 and ft["dtype"] != "video":
                robot_cols[key] = np.random.random(
                    (len(index_col), ft["shape"][0])
                ).astype(ft["dtype"])

        hf_features = get_hf_features_from_features(features)
        dataset = datasets.Dataset.from_dict(
@@ -340,7 +372,9 @@ def lerobot_dataset_metadata_factory(
            tasks = tasks_factory(total_tasks=info["total_tasks"])
        if not episodes:
            episodes = episodes_factory(
                total_episodes=info["total_episodes"],
                total_frames=info["total_frames"],
                tasks=tasks,
            )

        mock_snapshot_download = mock_snapshot_download_factory(
@@ -392,7 +426,9 @@ def lerobot_dataset_factory(
    ) -> LeRobotDataset:
        if not info:
            info = info_factory(
                total_episodes=total_episodes,
                total_frames=total_frames,
                total_tasks=total_tasks,
            )
        if not stats:
            stats = stats_factory(features=info["features"])
@@ -408,7 +444,9 @@ def lerobot_dataset_factory(
                multi_task=multi_task,
            )
        if not hf_dataset:
            hf_dataset = hf_dataset_factory(
                tasks=tasks, episodes=episode_dicts, fps=info["fps"]
            )

        mock_snapshot_download = mock_snapshot_download_factory(
            info=info,

View File

@@ -102,7 +102,10 @@ def episode_path(episodes_factory):
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
    def _create_single_episode_parquet(
        dir: Path,
        ep_idx: int = 0,
        hf_dataset: datasets.Dataset | None = None,
        info: dict | None = None,
    ) -> Path:
        if not info:
            info = info_factory()

tests/fixtures/hub.py

@@ -67,15 +67,21 @@ def mock_snapshot_download_factory(
        tasks = tasks_factory(total_tasks=info["total_tasks"])
    if not episodes:
        episodes = episodes_factory(
            total_episodes=info["total_episodes"],
            total_frames=info["total_frames"],
            tasks=tasks,
        )
    if not hf_dataset:
        hf_dataset = hf_dataset_factory(
            tasks=tasks, episodes=episodes, fps=info["fps"]
        )

    def _extract_episode_index_from_path(fpath: str) -> int:
        path = Path(fpath)
        if path.suffix == ".parquet" and path.stem.startswith("episode_"):
            episode_index = int(
                path.stem[len("episode_") :]
            )  # 'episode_000000' -> 0
            return episode_index
        else:
            return None
@@ -100,12 +106,16 @@ def mock_snapshot_download_factory(
        for episode_dict in episodes.values():
            ep_idx = episode_dict["episode_index"]
            ep_chunk = ep_idx // info["chunks_size"]
            data_path = info["data_path"].format(
                episode_chunk=ep_chunk, episode_index=ep_idx
            )
            data_files.append(data_path)
        all_files.extend(data_files)

        allowed_files = filter_repo_objects(
            all_files,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )

        # Create allowed files
@@ -113,7 +123,9 @@ def mock_snapshot_download_factory(
            if rel_path.startswith("data/"):
                episode_index = _extract_episode_index_from_path(rel_path)
                if episode_index is not None:
                    _ = single_episode_parquet_path(
                        local_dir, episode_index, hf_dataset, info
                    )
            if rel_path == INFO_PATH:
                _ = info_path(local_dir, info)
            elif rel_path == STATS_PATH:


@@ -80,7 +80,9 @@ class GroupSyncRead:
    def addParam(self, motor_index):  # noqa: N802
        # Initialize motor default values
        if motor_index not in self.packet_handler.data:
            self.packet_handler.data[motor_index] = get_default_motor_values(
                motor_index
            )

    def txRxPacket(self):  # noqa: N802
        return COMM_SUCCESS


@@ -91,7 +91,9 @@ class GroupSyncRead:
    def addParam(self, motor_index):  # noqa: N802
        # Initialize motor default values
        if motor_index not in self.packet_handler.data:
            self.packet_handler.data[motor_index] = get_default_motor_values(
                motor_index
            )

    def txRxPacket(self):  # noqa: N802
        return COMM_SUCCESS


@@ -43,7 +43,10 @@ import time
import numpy as np
import pytest

from lerobot.common.robot_devices.utils import (
    RobotDeviceAlreadyConnectedError,
    RobotDeviceNotConnectedError,
)
from lerobot.scripts.find_motors_bus_port import find_port
from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor
@@ -76,7 +79,9 @@ def test_configure_motors_all_ids_1(request, motor_type, mock):
    else:
        raise ValueError(motor_type)

    input(
        "Are you sure you want to re-configure the motors? Press enter to continue..."
    )
    # This test expect the configuration was already correct.
    motors_bus = make_motors_bus(motor_type, mock=mock)
    motors_bus.connect()


@@ -25,7 +25,10 @@ from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
    Classifier,
    ClassifierConfig,
)

BATCH_SIZE = 1000
LR = 0.1
@@ -43,7 +46,9 @@ def train_evaluate_multiclass_classifier():
    logging.info(
        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )
    multiclass_config = ClassifierConfig(
        model_name="microsoft/resnet-18", device=DEVICE, num_classes=10
    )
    multiclass_classifier = Classifier(multiclass_config)

    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
@@ -114,10 +119,18 @@ def train_evaluate_multiclass_classifier():
    test_probs = torch.stack(test_probs)

    accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
    precision = Precision(
        task="multiclass", average="weighted", num_classes=multiclass_num_classes
    )
    recall = Recall(
        task="multiclass", average="weighted", num_classes=multiclass_num_classes
    )
    f1 = F1Score(
        task="multiclass", average="weighted", num_classes=multiclass_num_classes
    )
    auroc = AUROC(
        task="multiclass", num_classes=multiclass_num_classes, average="weighted"
    )

    # Calculate metrics
    acc = accuracy(test_predictions, test_labels)
@@ -146,18 +159,28 @@ def train_evaluate_binary_classifier():
            new_label = float(1.0) if label == target_class else float(0.0)
            new_targets.append(new_label)

        dataset.targets = (
            new_targets  # Replace the original labels with the binary ones
        )
        return dataset

    binary_train_dataset = CIFAR10(
        root="data", train=True, download=True, transform=ToTensor()
    )
    binary_test_dataset = CIFAR10(
        root="data", train=False, download=True, transform=ToTensor()
    )

    # Apply one-vs-rest labeling
    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)

    binary_trainloader = DataLoader(
        binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    binary_testloader = DataLoader(
        binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False
    )

    binary_epoch = 1


@@ -9,7 +9,9 @@ from tests.utils import require_package
def test_classifier_output():
    output = ClassifierOutput(
        logits=torch.tensor([1, 2, 3]),
        probabilities=torch.tensor([0.1, 0.2, 0.3]),
        hidden_states=None,
    )

    assert (
@@ -20,7 +22,9 @@ def test_classifier_output():
@require_package("transformers")
def test_binary_classifier_with_default_params():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
        Classifier,
    )

    config = ClassifierConfig()
    classifier = Classifier(config)
@@ -41,7 +45,9 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
def test_multiclass_classifier():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
        Classifier,
    )

    num_classes = 5
    config = ClassifierConfig(num_classes=num_classes)
@@ -63,7 +69,9 @@ def test_multiclass_classifier():
@require_package("transformers")
def test_default_device():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
        Classifier,
    )

    config = ClassifierConfig()
    assert config.device == "cpu"
@@ -75,7 +83,9 @@ def test_default_device():
@require_package("transformers")
def test_explicit_device_setup():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
        Classifier,
    )

    config = ClassifierConfig(device="meta")
    assert config.device == "meta"


@@ -172,7 +172,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
    # Test updating the policy (and test that it does not mutate the batch)
    batch_ = deepcopy(batch)
    policy.forward(batch)
    assert set(batch) == set(
        batch_
    ), "Batch keys are not the same after a forward pass."
    assert all(
        torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
        for k in batch
@@ -186,7 +188,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
    observation = preprocess_observation(observation)

    # send observation to device/gpu
    observation = {
        key: observation[key].to(DEVICE, non_blocking=True) for key in observation
    }

    # get the next action for the environment (also check that the observation batch is not modified)
    observation_ = deepcopy(observation)
@@ -452,7 +456,9 @@ def test_act_temporal_ensembler():
    batch_size = batch_seq.shape[0]
    # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
    # dimension of `batch_seq`.
    weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(
        -1
    )

    # Simulate stepping through a rollout and computing a batch of actions with model on each step.
    for i in range(episode_length):
@@ -475,7 +481,8 @@ def test_act_temporal_ensembler():
        episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
        seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
        offline_avg = (
            einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum")
            / weights[: i + 1].sum()
        )
        # Sanity check. The average should be between the extrema.
        assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)


@@ -335,8 +335,12 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
    )
    dataset = record(robot, rec_cfg)

    assert not mock_events[
        "rerecord_episode"
    ], "`rerecord_episode` wasn't properly reset to False"
    assert not mock_events[
        "exit_early"
    ], "`exit_early` wasn't properly reset to False"
    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@@ -391,7 +395,8 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
@pytest.mark.parametrize(
    "robot_type, mock, num_image_writer_processes",
    [("koch", True, 0), ("koch", True, 1)],
)
@require_robot
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):


@@ -105,7 +105,9 @@ def test_robot(tmp_path, request, robot_type, mock):
    assert "observation.state" in observation
    assert isinstance(observation["observation.state"], torch.Tensor)
    assert observation["observation.state"].ndim == 1
    dim_state = sum(
        len(robot.follower_arms[name].motors) for name in robot.follower_arms
    )
    assert observation["observation.state"].shape[0] == dim_state
    # Cameras
    for name in robot.cameras:
@@ -116,7 +118,9 @@ def test_robot(tmp_path, request, robot_type, mock):
    assert "action" in action
    assert isinstance(action["action"], torch.Tensor)
    assert action["action"].ndim == 1
    dim_action = sum(
        len(robot.follower_arms[name].motors) for name in robot.follower_arms
    )
    assert action["action"].shape[0] == dim_action

    # TODO(rcadene): test if observation and action data are returned as expected


@@ -9,7 +9,9 @@ from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset

from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
    ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
    create_balanced_sampler,
@@ -34,7 +36,9 @@ class MockDataset(Dataset):
def make_dummy_model():
    model_config = ClassifierConfig(
        num_classes=2,
        model_name="hf-tiny-model-private/tiny-random-ResNetModel",
        num_cameras=1,
    )
    model = Classifier(config=model_config)
    return model
@@ -65,7 +69,9 @@ def test_create_balanced_sampler():
    labels = [item["label"] for item in data]
    class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
    class_weights = 1.0 / class_counts
    expected_weights = torch.tensor(
        [class_weights[label] for label in labels], dtype=torch.float32
    )

    # Test that the weights are correct
    assert torch.allclose(weights, expected_weights)
@@ -149,7 +155,9 @@ def test_validate():
def test_train_epoch_multiple_cameras():
    model_config = ClassifierConfig(
        num_classes=2,
        model_name="hf-tiny-model-private/tiny-random-ResNetModel",
        num_cameras=2,
    )
    model = Classifier(config=model_config)
@@ -216,10 +224,16 @@ def test_resume_function(
):
    # Initialize Hydra
    test_file_dir = os.path.dirname(os.path.abspath(__file__))
    config_dir = os.path.abspath(
        os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
    )
    assert os.path.exists(
        config_dir
    ), f"Config directory does not exist at {config_dir}"

    with initialize_config_dir(
        config_dir=config_dir, job_name="test_app", version_base="1.2"
    ):
        cfg = compose(
            config_name="hilserl_classifier",
            overrides=[
@@ -244,7 +258,9 @@ def test_resume_function(
    mock_init_hydra_config.return_value = cfg

    # Mock dataset
    dataset = MockDataset(
        [{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]
    )
    mock_dataset.return_value = dataset

    # Mock checkpoint handling


@@ -47,7 +47,9 @@ for motor_type in available_motors:
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))

DYNAMIXEL_PORT = os.environ.get(
    "LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081"
)
DYNAMIXEL_MOTORS = {
    "shoulder_pan": [1, "xl430-w250"],
    "shoulder_lift": [2, "xl430-w250"],
@@ -57,7 +59,9 @@ DYNAMIXEL_MOTORS = {
    "gripper": [6, "xl330-m288"],
}

FEETECH_PORT = os.environ.get(
    "LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971"
)
FEETECH_MOTORS = {
    "shoulder_pan": [1, "sts3215"],
    "shoulder_lift": [2, "sts3215"],
@@ -156,9 +160,13 @@ def require_package_arg(func):
        if "required_packages" in arg_names:
            # Get the index of 'required_packages' and retrieve the value from args
            index = arg_names.index("required_packages")
            required_packages = (
                args[index] if len(args) > index else kwargs.get("required_packages")
            )
        else:
            raise ValueError(
                "Function does not have 'required_packages' as an argument."
            )

        if required_packages is None:
            return func(*args, **kwargs)
@@ -215,11 +223,17 @@ def require_robot(func):
        mock = kwargs.get("mock")

        if robot_type is None:
            raise ValueError(
                "The 'robot_type' must be an argument of the test function."
            )
        if request is None:
            raise ValueError(
                "The 'request' fixture must be an argument of the test function."
            )
        if mock is None:
            raise ValueError(
                "The 'mock' variable must be an argument of the test function."
            )

        # Run test with a real robot. Skip test if robot connection fails.
        if not mock and not request.getfixturevalue("is_robot_available"):
@@ -239,11 +253,17 @@ def require_camera(func):
        mock = kwargs.get("mock")

        if request is None:
            raise ValueError(
                "The 'request' fixture must be an argument of the test function."
            )
        if camera_type is None:
            raise ValueError(
                "The 'camera_type' must be an argument of the test function."
            )
        if mock is None:
            raise ValueError(
                "The 'mock' variable must be an argument of the test function."
            )

        if not mock and not request.getfixturevalue("is_camera_available"):
            pytest.skip(f"A {camera_type} camera is not available.")
@@ -262,11 +282,17 @@ def require_motor(func):
        mock = kwargs.get("mock")

        if request is None:
            raise ValueError(
                "The 'request' fixture must be an argument of the test function."
            )
        if motor_type is None:
            raise ValueError(
                "The 'motor_type' must be an argument of the test function."
            )
        if mock is None:
            raise ValueError(
                "The 'mock' variable must be an argument of the test function."
            )

        if not mock and not request.getfixturevalue("is_motor_available"):
            pytest.skip(f"A {motor_type} motor is not available.")
@@ -285,7 +311,14 @@ def mock_calibration_dir(calibration_dir):
        "start_pos": [1442, 843, 2166, 2849, 1988, 1835],
        "end_pos": [2440, 1869, -1106, -1848, -926, 3235],
        "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
        "motor_names": [
            "shoulder_pan",
            "shoulder_lift",
            "elbow_flex",
            "wrist_flex",
            "wrist_roll",
            "gripper",
        ],
    }
    Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
    with open(calibration_dir / "main_follower.json", "w") as f: