diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index c62578c4..64240282 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -32,7 +32,11 @@ import numpy as np import pandas as pd import PIL import torch -from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from skimage.metrics import ( + mean_squared_error, + peak_signal_noise_ratio, + structural_similarity, +) from tqdm import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int: return total_size -def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor: +def load_original_frames( + imgs_dir: Path, timestamps: list[float], fps: int +) -> torch.Tensor: frames = [] for ts in timestamps: idx = int(ts * fps) @@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t def save_decoded_frames( - imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int + imgs_dir: Path, + save_dir: Path, + frames: torch.Tensor, + timestamps: list[float], + fps: int, ) -> None: if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps): return @@ -104,7 +114,10 @@ def save_decoded_frames( idx = int(ts * fps) frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy() PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png") - shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png") + shutil.copyfile( + imgs_dir / f"frame_{idx:06d}.png", + save_dir / f"frame_{idx:06d}_original.png", + ) def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: @@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: hf_dataset = dataset.hf_dataset.with_format(None) # We only save images from the first camera - img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] + img_keys = [ + key for key in hf_dataset.features if key.startswith("observation.image") + ] imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate( - tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False) + tqdm( + imgs_dataset, + desc=f"saving {dataset.repo_id} first episode images", + leave=False, + ) ): img = item[img_keys[0]] img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100) @@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: break -def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]: +def sample_timestamps( + timestamps_mode: str, ep_num_images: int, fps: int +) -> list[float]: # Start at 5 to allow for 2_frames_4_space and 6_frames idx = random.randint(5, ep_num_images - 1) match timestamps_mode: @@ -154,7 +175,9 @@ def decode_video_frames( backend: str, ) -> torch.Tensor: if backend in ["pyav", "video_reader"]: - return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + return decode_video_frames_torchvision( + video_path, timestamps, tolerance_s, backend + ) else: raise NotImplementedError(backend) @@ -181,7 +204,9 @@ def benchmark_decoding( } with time_benchmark: - frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend) + frames = decode_video_frames( + video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend + ) result["load_time_video_ms"] = time_benchmark.result_ms / num_frames with time_benchmark: @@ -190,12 +215,18 @@ def benchmark_decoding( frames_np, original_frames_np = frames.numpy(), original_frames.numpy() for i in range(num_frames): - result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i])) + result["mse_values"].append( + mean_squared_error(original_frames_np[i], frames_np[i]) + ) result["psnr_values"].append( - peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0) + peak_signal_noise_ratio( + original_frames_np[i], frames_np[i], data_range=1.0 + ) ) result["ssim_values"].append( - structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0) + structural_similarity( + original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0 + ) ) if save_frames and sample == 0: @@ -215,7 +246,9 @@ def benchmark_decoding( # As these samples are independent, we run them in parallel threads to speed up the benchmark. with ThreadPoolExecutor(max_workers=num_workers) as executor: futures = [executor.submit(process_sample, i) for i in range(num_samples)] - for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False): + for future in tqdm( + as_completed(futures), total=num_samples, desc="samples", leave=False + ): result = future.result() load_times_video_ms.append(result["load_time_video_ms"]) load_times_images_ms.append(result["load_time_images_ms"]) @@ -275,9 +308,13 @@ def benchmark_encoding_decoding( random.seed(seed) benchmark_table = [] for timestamps_mode in tqdm( - decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False + decoding_cfg["timestamps_modes"], + desc="decodings (timestamps_modes)", + leave=False, ): - for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False): + for backend in tqdm( + decoding_cfg["backends"], desc="decodings (backends)", leave=False + ): benchmark_row = benchmark_decoding( imgs_dir, video_path, @@ -355,14 +392,23 @@ def main( imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_") # We only use the first episode save_first_episode(imgs_dir, dataset) - for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False): + for key, values in tqdm( + encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False + ): for value in tqdm(values, desc=f"encodings ({key})", leave=False): encoding_cfg = BASE_ENCODING.copy() encoding_cfg["vcodec"] = video_codec encoding_cfg["pix_fmt"] = pixel_format encoding_cfg[key] = value - args_path = Path("_".join(str(value) for value in encoding_cfg.values())) - video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4" + args_path = Path( + "_".join(str(value) for value in encoding_cfg.values()) + ) + video_path = ( + output_dir + / "videos" + / args_path + / f"{repo_id.replace('/', '_')}.mp4" + ) benchmark_table += benchmark_encoding_decoding( dataset, video_path, @@ -388,7 +434,9 @@ def main( # Concatenate all results df_list = [pd.read_csv(csv_path) for csv_path in file_paths] concatenated_df = pd.concat(df_list, ignore_index=True) - concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv" + concatenated_path = ( + output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv" + ) concatenated_df.to_csv(concatenated_path, header=True, index=False) diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py index c374a375..7b0b9846 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/1_load_lerobot_dataset.py @@ -32,7 +32,10 @@ import torch from huggingface_hub import HfApi import lerobot -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.common.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, +) # We ported a number of existing datasets ourselves, use this to see the list: print("List of available datasets:") @@ -40,7 +43,10 @@ pprint(lerobot.available_datasets) # You can also browse through the datasets created/ported by the community on the hub using the hub api: hub_api = HfApi() -repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])] +repo_ids = [ + info.id + for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"]) +] pprint(repo_ids) # Or simply explore them in your web browser directly at: @@ -55,7 +61,9 @@ ds_meta = LeRobotDatasetMetadata(repo_id) # structure of the dataset without downloading the actual data yet (only metadata files — which are # lightweight). print(f"Total number of episodes: {ds_meta.total_episodes}") -print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}") +print( + f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}" +) print(f"Frames per second used during data collection: {ds_meta.fps}") print(f"Robot type: {ds_meta.robot_type}") print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n") diff --git a/examples/advanced/1_add_image_transforms.py b/examples/advanced/1_add_image_transforms.py index f1460926..78dc6152 100644 --- a/examples/advanced/1_add_image_transforms.py +++ b/examples/advanced/1_add_image_transforms.py @@ -48,10 +48,14 @@ transforms = v2.Compose( ) # Create another LeRobotDataset with the defined transformations -transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms) +transformed_dataset = LeRobotDataset( + dataset_repo_id, episodes=[0], image_transforms=transforms +) # Get a frame from the transformed dataset -transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]] +transformed_frame = transformed_dataset[first_idx][ + transformed_dataset.meta.camera_keys[0] +] # Create a directory to store output images output_dir = Path("outputs/image_transforms") diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index 47b4dd02..80c9f3a8 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -26,7 +26,10 @@ import math import torch -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.common.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, +) from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py new file mode 100644 index 00000000..6766ac83 --- /dev/null +++ b/examples/port_datasets/pusht_zarr.py @@ -0,0 +1,228 @@ +import shutil +from pathlib import Path + +import numpy as np +import torch + +from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset +from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw + +PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface." +PUSHT_FEATURES = { + "observation.state": { + "dtype": "float32", + "shape": (2,), + "names": { + "axes": ["x", "y"], + }, + }, + "action": { + "dtype": "float32", + "shape": (2,), + "names": { + "axes": ["x", "y"], + }, + }, + "next.reward": { + "dtype": "float32", + "shape": (1,), + "names": None, + }, + "next.success": { + "dtype": "bool", + "shape": (1,), + "names": None, + }, + "observation.environment_state": { + "dtype": "float32", + "shape": (16,), + "names": [ + "keypoints", + ], + }, + "observation.image": { + "dtype": None, + "shape": (3, 96, 96), + "names": [ + "channel", + "height", + "width", + ], + }, +} + + +def build_features(mode: str) -> dict: + features = PUSHT_FEATURES + if mode == "keypoints": + features.pop("observation.image") + else: + features.pop("observation.environment_state") + features["observation.image"]["dtype"] = mode + + return features + + +def load_raw_dataset(zarr_path: Path): + try: + from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import ( + ReplayBuffer as DiffusionPolicyReplayBuffer, + ) + except ModuleNotFoundError as e: + print( + "`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`" + ) + raise e + + zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) + return zarr_data + + +def calculate_coverage(zarr_data): + try: + import pymunk + from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely + except ModuleNotFoundError as e: + print( + "`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`" + ) + raise e + + block_pos = zarr_data["state"][:, 2:4] + block_angle = zarr_data["state"][:, 4] + + num_frames = len(block_pos) + + coverage = np.zeros((num_frames,)) + # 8 keypoints with 2 coords each + keypoints = np.zeros((num_frames, 16)) + + # Set x, y, theta (in radians) + goal_pos_angle = np.array([256, 256, np.pi / 4]) + goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) + + for i in range(num_frames): + space = pymunk.Space() + space.gravity = 0, 0 + space.damping = 0 + + # Add walls. + walls = [ + PushTEnv.add_segment(space, (5, 506), (5, 5), 2), + PushTEnv.add_segment(space, (5, 5), (506, 5), 2), + PushTEnv.add_segment(space, (506, 5), (506, 506), 2), + PushTEnv.add_segment(space, (5, 506), (506, 506), 2), + ] + space.add(*walls) + + block_body, block_shapes = PushTEnv.add_tee( + space, block_pos[i].tolist(), block_angle[i].item() + ) + goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) + block_geom = pymunk_to_shapely(block_body, block_body.shapes) + intersection_area = goal_geom.intersection(block_geom).area + goal_area = goal_geom.area + coverage[i] = intersection_area / goal_area + keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten()) + + return coverage, keypoints + + +def calculate_success(coverage: float, success_threshold: float): + return coverage > success_threshold + + +def calculate_reward(coverage: float, success_threshold: float): + return np.clip(coverage / success_threshold, 0, 1) + + +def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True): + if mode not in ["video", "image", "keypoints"]: + raise ValueError(mode) + + if (LEROBOT_HOME / repo_id).exists(): + shutil.rmtree(LEROBOT_HOME / repo_id) + + if not raw_dir.exists(): + download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw") + + zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr") + + env_state = zarr_data["state"][:] + agent_pos = env_state[:, :2] + + action = zarr_data["action"][:] + image = zarr_data["img"] # (b, h, w, c) + + episode_data_index = { + "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])), + "to": zarr_data.meta["episode_ends"], + } + + # Calculate success and reward based on the overlapping area + # of the T-object and the T-area. + coverage, keypoints = calculate_coverage(zarr_data) + success = calculate_success(coverage, success_threshold=0.95) + reward = calculate_reward(coverage, success_threshold=0.95) + + features = build_features(mode) + dataset = LeRobotDataset.create( + repo_id=repo_id, + fps=10, + robot_type="2d pointer", + features=features, + image_writer_threads=4, + ) + episodes = range(len(episode_data_index["from"])) + for ep_idx in episodes: + from_idx = episode_data_index["from"][ep_idx] + to_idx = episode_data_index["to"][ep_idx] + num_frames = to_idx - from_idx + + for frame_idx in range(num_frames): + i = from_idx + frame_idx + frame = { + "action": torch.from_numpy(action[i]), + # Shift reward and success by +1 until the last item of the episode + "next.reward": reward[i + (frame_idx < num_frames - 1)], + "next.success": success[i + (frame_idx < num_frames - 1)], + } + + frame["observation.state"] = torch.from_numpy(agent_pos[i]) + + if mode == "keypoints": + frame["observation.environment_state"] = torch.from_numpy(keypoints[i]) + else: + frame["observation.image"] = torch.from_numpy(image[i]) + + dataset.add_frame(frame) + + dataset.save_episode(task=PUSHT_TASK) + + dataset.consolidate() + + if push_to_hub: + dataset.push_to_hub() + + +if __name__ == "__main__": + # To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht) + repo_id = "lerobot/pusht" + + modes = ["video", "image", "keypoints"] + # Uncomment if you want to try with a specific mode + # modes = ["video"] + # modes = ["image"] + # modes = ["keypoints"] + + raw_dir = Path("data/lerobot-raw/pusht_raw") + for mode in modes: + if mode in ["image", "keypoints"]: + repo_id += f"_{mode}" + + # download and load raw dataset, create LeRobotDataset, populate it, push to hub + main(raw_dir, repo_id=repo_id, mode=mode) + + # Uncomment if you want to load the local dataset and explore it + # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True) + # breakpoint() diff --git a/lerobot/__init__.py b/lerobot/__init__.py index d61e4853..dec96226 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -164,7 +164,11 @@ available_real_world_datasets = [ ] available_datasets = sorted( - set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)) + set( + itertools.chain( + *available_datasets_per_env.values(), available_real_world_datasets + ) + ) ) # lists all available policies from `lerobot/common/policies` @@ -205,9 +209,13 @@ available_policies_per_env = { "aloha_real": ["act_aloha_real"], } -env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] +env_task_pairs = [ + (env, task) for env, tasks in available_tasks_per_env.items() for task in tasks +] env_dataset_pairs = [ - (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets + (env, dataset) + for env, datasets in available_datasets_per_env.items() + for dataset in datasets ] env_dataset_policy_triplets = [ (env, dataset, policy) diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 6fc0ee2f..2dd685eb 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -127,7 +127,9 @@ class AsyncImageWriter: self._stopped = False if num_threads <= 0 and num_processes <= 0: - raise ValueError("Number of threads and processes must be greater than zero.") + raise ValueError( + "Number of threads and processes must be greater than zero." + ) if self.num_processes == 0: # Use threading @@ -141,12 +143,16 @@ class AsyncImageWriter: # Use multiprocessing self.queue = multiprocessing.JoinableQueue() for _ in range(self.num_processes): - p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads)) + p = multiprocessing.Process( + target=worker_process, args=(self.queue, self.num_threads) + ) p.daemon = True p.start() self.processes.append(p) - def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path): + def save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path + ): if isinstance(image, torch.Tensor): # Convert tensor to numpy array to minimize main process time image = image.cpu().numpy() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 09615767..61fd6cc5 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -139,7 +139,9 @@ class LeRobotDatasetMetadata: def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) + fpath = self.video_path.format( + episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index + ) return Path(fpath) def get_episode_chunk(self, ep_index: int) -> int: @@ -183,7 +185,11 @@ class LeRobotDatasetMetadata: @property def camera_keys(self) -> list[str]: """Keys to access visual modalities (regardless of their storage method).""" - return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + return [ + key + for key, ft in self.features.items() + if ft["dtype"] in ["video", "image"] + ] @property def names(self) -> dict[str, list | dict]: @@ -285,7 +291,9 @@ class LeRobotDatasetMetadata: """ for key in self.video_keys: if not self.features[key].get("info", None): - video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) + video_path = self.root / self.get_video_file_path( + ep_index=0, vid_key=key + ) self.info["features"][key]["info"] = get_video_info(video_path) def __repr__(self): @@ -619,7 +627,10 @@ class LeRobotDataset(torch.utils.data.Dataset): path = str(self.root / "data") hf_dataset = load_dataset("parquet", data_dir=path, split="train") else: - files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] + files = [ + str(self.root / self.meta.get_data_file_path(ep_idx)) + for ep_idx in self.episodes + ] hf_dataset = load_dataset("parquet", data_files=files, split="train") # TODO(aliberts): hf_dataset.set_format("torch") @@ -643,12 +654,20 @@ class LeRobotDataset(torch.utils.data.Dataset): @property def num_frames(self) -> int: """Number of frames in selected episodes.""" - return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + return ( + len(self.hf_dataset) + if self.hf_dataset is not None + else self.meta.total_frames + ) @property def num_episodes(self) -> int: """Number of episodes selected.""" - return len(self.episodes) if self.episodes is not None else self.meta.total_episodes + return ( + len(self.episodes) + if self.episodes is not None + else self.meta.total_episodes + ) @property def features(self) -> dict[str, dict]: @@ -662,16 +681,24 @@ class LeRobotDataset(torch.utils.data.Dataset): else: return get_hf_features_from_features(self.features) - def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + def _get_query_indices( + self, idx: int, ep_idx: int + ) -> tuple[dict[str, list[int | bool]]]: ep_start = self.episode_data_index["from"][ep_idx] ep_end = self.episode_data_index["to"][ep_idx] query_indices = { - key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] + key: [ + max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) + for delta in delta_idx + ] for key, delta_idx in self.delta_indices.items() } padding = { # Pad values outside of current episode range f"{key}_is_pad": torch.BoolTensor( - [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx] + [ + (idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) + for delta in delta_idx + ] ) for key, delta_idx in self.delta_indices.items() } @@ -771,13 +798,17 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_buffer[key] = current_ep_idx if key == "episode_index" else [] return ep_buffer - def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + def _get_image_file_path( + self, episode_index: int, image_key: str, frame_index: int + ) -> Path: fpath = DEFAULT_IMAGE_PATH.format( image_key=image_key, episode_index=episode_index, frame_index=frame_index ) return self.root / fpath - def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: + def _save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path + ) -> None: if self.image_writer is None: if isinstance(image, torch.Tensor): image = image.cpu().numpy() @@ -803,7 +834,9 @@ class LeRobotDataset(torch.utils.data.Dataset): # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps + timestamp = ( + frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps + ) self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) @@ -821,7 +854,9 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.features[key]["dtype"] in ["image", "video"]: img_path = self._get_image_file_path( - episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + episode_index=self.episode_buffer["episode_index"], + image_key=key, + frame_index=frame_index, ) if frame_index == 0: img_path.parent.mkdir(parents=True, exist_ok=True) @@ -1132,7 +1167,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): def features(self) -> datasets.Features: features = {} for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + features.update( + { + k: v + for k, v in dataset.hf_features.items() + if k not in self.disabled_features + } + ) return features @property @@ -1193,7 +1234,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): continue break else: - raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + raise AssertionError( + "We expect the loop to break out as long as the index is within bounds." + ) item = self._datasets[dataset_idx][idx - start_idx] item["dataset_index"] = torch.tensor(dataset_idx) for data_key in self.disabled_features: diff --git a/lerobot/common/datasets/online_buffer.py b/lerobot/common/datasets/online_buffer.py index d907e468..e31206fa 100644 --- a/lerobot/common/datasets/online_buffer.py +++ b/lerobot/common/datasets/online_buffer.py @@ -131,7 +131,9 @@ class OnlineBuffer(torch.utils.data.Dataset): else: self._delta_timestamps = None - def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]: + def _make_data_spec( + self, data_spec: dict[str, Any], buffer_capacity: int + ) -> dict[str, dict[str, Any]]: """Makes the data spec for np.memmap.""" if any(k.startswith("_") for k in data_spec): raise ValueError( @@ -154,14 +156,32 @@ class OnlineBuffer(torch.utils.data.Dataset): OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()}, # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied # with real data rather than the dummy initialization. - OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)}, - OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)}, + OnlineBuffer.OCCUPANCY_MASK_KEY: { + "dtype": np.dtype("?"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.INDEX_KEY: { + "dtype": np.dtype("int64"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.FRAME_INDEX_KEY: { + "dtype": np.dtype("int64"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.EPISODE_INDEX_KEY: { + "dtype": np.dtype("int64"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.TIMESTAMP_KEY: { + "dtype": np.dtype("float64"), + "shape": (buffer_capacity,), + }, } for k, v in data_spec.items(): - complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])} + complete_data_spec[k] = { + "dtype": v["dtype"], + "shape": (buffer_capacity, *v["shape"]), + } return complete_data_spec def add_data(self, data: dict[str, np.ndarray]): @@ -188,7 +208,9 @@ class OnlineBuffer(torch.utils.data.Dataset): # Shift the incoming indices if necessary. if self.num_frames > 0: - last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] + last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][ + next_index - 1 + ] last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 @@ -223,7 +245,11 @@ class OnlineBuffer(torch.utils.data.Dataset): @property def num_episodes(self) -> int: return len( - np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) + np.unique( + self._data[OnlineBuffer.EPISODE_INDEX_KEY][ + self._data[OnlineBuffer.OCCUPANCY_MASK_KEY] + ] + ) ) @property @@ -261,7 +287,9 @@ class OnlineBuffer(torch.utils.data.Dataset): self._data[OnlineBuffer.OCCUPANCY_MASK_KEY], ) )[0] - episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices] + episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][ + episode_data_indices + ] for data_key in self.delta_timestamps: # Note: The logic in this loop is copied from `load_previous_and_future_frames`. @@ -278,7 +306,8 @@ class OnlineBuffer(torch.utils.data.Dataset): # Check violated query timestamps are all outside the episode range. assert ( - (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad]) + (query_ts[is_pad] < episode_timestamps[0]) + | (episode_timestamps[-1] < query_ts[is_pad]) ).all(), ( f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}" ") inside the episode range." @@ -293,7 +322,9 @@ class OnlineBuffer(torch.utils.data.Dataset): def get_data_by_key(self, key: str) -> torch.Tensor: """Returns all data for a given data key as a Tensor.""" - return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) + return torch.from_numpy( + self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]] + ) def compute_sampler_weights( @@ -324,13 +355,19 @@ def compute_sampler_weights( - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not included here to avoid adding complexity. """ - if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0): - raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.") + if len(offline_dataset) == 0 and ( + online_dataset is None or len(online_dataset) == 0 + ): + raise ValueError( + "At least one of `offline_dataset` or `online_dataset` should be contain data." + ) if (online_dataset is None) ^ (online_sampling_ratio is None): raise ValueError( "`online_dataset` and `online_sampling_ratio` must be provided together or not at all." ) - offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio + offline_sampling_ratio = ( + 0 if online_sampling_ratio is None else 1 - online_sampling_ratio + ) weights = [] diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py index ebcf87f7..13997c81 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/utils.py +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts): return data_dict -def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4): +def save_images_concurrently( + imgs_array: numpy.array, out_dir: Path, max_workers: int = 4 +): out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) @@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers num_images = len(imgs_array) with ThreadPoolExecutor(max_workers=max_workers) as executor: - [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)] + [ + executor.submit(save_image, imgs_array[i], i, out_dir) + for i in range(num_images) + ] def get_default_encoding() -> dict: @@ -64,7 +69,8 @@ def get_default_encoding() -> dict: return { k: v.default for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"] + if v.default is not inspect.Parameter.empty + and k in ["vcodec", "pix_fmt", "g", "crf"] } @@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None: # TODO(aliberts): remove -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: +def calculate_episode_data_index( + hf_dataset: datasets.Dataset, +) -> Dict[str, torch.Tensor]: """ Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py index 2f6c15c1..53d0e2e4 100644 --- a/lerobot/common/datasets/sampler.py +++ b/lerobot/common/datasets/sampler.py @@ -43,7 +43,10 @@ class EpisodeAwareSampler: ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: indices.extend( - range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) + range( + start_index.item() + drop_n_first_frames, + end_index.item() - drop_n_last_frames, + ) ) self.indices = indices diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 720c939b..401120da 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -58,7 +58,9 @@ class RandomSubsetApply(Transform): elif not isinstance(n_subset, int): raise TypeError("n_subset should be an int or None") elif not (1 <= n_subset <= len(transforms)): - raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") + raise ValueError( + f"n_subset should be in the interval [1, {len(transforms)}]" + ) self.transforms = transforms total = sum(p) @@ -119,16 +121,22 @@ class SharpnessJitter(Transform): def _check_input(self, sharpness): if isinstance(sharpness, (int, float)): if sharpness < 0: - raise ValueError("If sharpness is a single number, it must be non negative.") + raise ValueError( + "If sharpness is a single number, it must be non negative." + ) sharpness = [1.0 - sharpness, 1.0 + sharpness] sharpness[0] = max(sharpness[0], 0.0) elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2: sharpness = [float(v) for v in sharpness] else: - raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") + raise TypeError( + f"{sharpness=} should be a single number or a sequence with length 2." + ) if not 0.0 <= sharpness[0] <= sharpness[1]: - raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") + raise ValueError( + f"sharpnesss values should be between (0., inf), but got {sharpness}." + ) return float(sharpness[0]), float(sharpness[1]) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 7e297b35..1050f6eb 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -52,9 +52,15 @@ STATS_PATH = "meta/stats.json" EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" TASKS_PATH = "meta/tasks.jsonl" -DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" -DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" -DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" +DEFAULT_VIDEO_PATH = ( + "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" +) +DEFAULT_PARQUET_PATH = ( + "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" +) +DEFAULT_IMAGE_PATH = ( + "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" +) DATASET_CARD_TEMPLATE = """ --- @@ -540,7 +546,10 @@ def check_timestamps_sync( def check_delta_timestamps( - delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True + delta_timestamps: dict[str, list[float]], + fps: int, + tolerance_s: float, + raise_value_error: bool = True, ) -> bool: """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance. This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be @@ -548,10 +557,14 @@ def check_delta_timestamps( """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): - within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] + within_tolerance = [ + abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts + ] if not all(within_tolerance): outside_tolerance[key] = [ - ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within + ts + for ts, is_within in zip(delta_ts, within_tolerance, strict=True) + if not is_within ] if len(outside_tolerance) > 0: @@ -569,7 +582,9 @@ def check_delta_timestamps( return True -def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: +def get_delta_indices( + delta_timestamps: dict[str, list[float]], fps: int +) -> dict[str, list[int]]: delta_indices = {} for key, delta_ts in delta_timestamps.items(): delta_indices[key] = [round(d * fps) for d in delta_ts] @@ -634,7 +649,9 @@ def create_lerobot_dataset_card( ], ) - card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text() + card_template = ( + importlib.resources.files("lerobot.common.datasets") / "card_template.md" + ).read_text() return DatasetCard.from_template( card_data=card_data, diff --git a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py index 99ab2cbf..58b334bc 100644 --- a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py @@ -118,7 +118,10 @@ DATASETS = { "single_task": "Place the battery into the slot of the remote controller.", **ALOHA_STATIC_INFO, }, - "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO}, + "aloha_static_candy": { + "single_task": "Pick up the candy and unwrap it.", + **ALOHA_STATIC_INFO, + }, "aloha_static_coffee": { "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.", **ALOHA_STATIC_INFO, @@ -167,13 +170,22 @@ DATASETS = { "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.", **ALOHA_STATIC_INFO, }, - "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO}, - "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, + "aloha_static_ziploc_slide": { + "single_task": "Slide open the ziploc bag.", + **ALOHA_STATIC_INFO, + }, + "aloha_sim_insertion_scripted": { + "single_task": "Insert the peg into the socket.", + **ALOHA_STATIC_INFO, + }, "aloha_sim_insertion_scripted_image": { "single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO, }, - "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, + "aloha_sim_insertion_human": { + "single_task": "Insert the peg into the socket.", + **ALOHA_STATIC_INFO, + }, "aloha_sim_insertion_human_image": { "single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO, @@ -194,10 +206,19 @@ DATASETS = { "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", **ALOHA_STATIC_INFO, }, - "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, - "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, + "pusht": { + "single_task": "Push the T-shaped block onto the T-shaped target.", + **PUSHT_INFO, + }, + "pusht_image": { + "single_task": "Push the T-shaped block onto the T-shaped target.", + **PUSHT_INFO, + }, "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO}, - "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO}, + "unitreeh1_rearrange_objects": { + "single_task": "Put the object into the bin.", + **UNITREEH_INFO, + }, "unitreeh1_two_robot_greeting": { "single_task": "Greet the other robot with a high five.", **UNITREEH_INFO, @@ -207,13 +228,31 @@ DATASETS = { **UNITREEH_INFO, }, "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, + "xarm_lift_medium_image": { + "single_task": "Pick up the cube and lift it.", + **XARM_INFO, + }, + "xarm_lift_medium_replay": { + "single_task": "Pick up the cube and lift it.", + **XARM_INFO, + }, + "xarm_lift_medium_replay_image": { + "single_task": "Pick up the cube and lift it.", + **XARM_INFO, + }, "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, + "xarm_push_medium_image": { + "single_task": "Push the cube onto the target.", + **XARM_INFO, + }, + "xarm_push_medium_replay": { + "single_task": "Push the cube onto the target.", + **XARM_INFO, + }, + "xarm_push_medium_replay_image": { + "single_task": "Push the cube onto the target.", + **XARM_INFO, + }, "umi_cup_in_the_wild": { "single_task": "Put the cup on the plate.", "license": "apache-2.0", diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index acf0282f..0fddeaf9 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -218,7 +218,9 @@ def get_features_from_hf_dataset( dtype = ft.feature.dtype shape = (ft.length,) motor_names = ( - robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)] + robot_config["names"][key] + if robot_config + else [f"motor_{i}" for i in range(ft.length)] ) assert len(motor_names) == shape[0] names = {"motors": motor_names} @@ -242,11 +244,15 @@ def get_features_from_hf_dataset( return features -def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]: +def add_task_index_by_episodes( + dataset: Dataset, tasks_by_episodes: dict +) -> tuple[Dataset, list[str]]: df = dataset.to_pandas() tasks = list(set(tasks_by_episodes.values())) tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)} - episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()} + episodes_to_task_index = { + ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items() + } df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int) features = dataset.features @@ -263,10 +269,19 @@ def add_task_index_from_tasks_col( # HACK: This is to clean some of the instructions in our version of Open X datasets prefix_to_clean = "tf.Tensor(b'" suffix_to_clean = "', shape=(), dtype=string)" - df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean) + df[tasks_col] = ( + df[tasks_col] + .str.removeprefix(prefix_to_clean) + .str.removesuffix(suffix_to_clean) + ) # Create task_index col - tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict() + tasks_by_episode = ( + df.groupby("episode_index")[tasks_col] + .unique() + .apply(lambda x: x.tolist()) + .to_dict() + ) tasks = df[tasks_col].unique().tolist() tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)} df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int) @@ -291,7 +306,9 @@ def split_parquet_by_episodes( for ep_chunk in range(total_chunks): ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) - chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) + chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format( + episode_chunk=ep_chunk + ) (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) for ep_idx in range(ep_chunk_start, ep_chunk_end): ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) @@ -323,7 +340,9 @@ def move_videos( videos_moved = False video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")] if len(video_files) == 0: - video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")] + video_files = [ + str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4") + ] videos_moved = True # Videos have already been moved assert len(video_files) == total_episodes * len(video_keys) @@ -354,7 +373,9 @@ def move_videos( target_path = DEFAULT_VIDEO_PATH.format( episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx ) - video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx) + video_file = V1_VIDEO_FILE.format( + video_key=vid_key, episode_index=ep_idx + ) if len(video_dirs) == 1: video_path = video_dirs[0] / video_file else: @@ -371,7 +392,9 @@ def move_videos( subprocess.run(["git", "push"], cwd=work_dir, check=True) -def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None: +def fix_lfs_video_files_tracking( + work_dir: Path, lfs_untracked_videos: list[str] +) -> None: """ HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case, there's no other option than to download the actual files and reupload them with lfs tracking. @@ -379,7 +402,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str] for i in range(0, len(lfs_untracked_videos), 100): files = lfs_untracked_videos[i : i + 100] try: - subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True) + subprocess.run( + ["git", "rm", "--cached", *files], + cwd=work_dir, + capture_output=True, + check=True, + ) except subprocess.CalledProcessError as e: print("git rm --cached ERROR:") print(e.stderr) @@ -390,10 +418,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str] subprocess.run(["git", "push"], cwd=work_dir, check=True) -def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None: +def fix_gitattributes( + work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path +) -> None: shutil.copyfile(clean_gittatributes, current_gittatributes) subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True) - subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True) + subprocess.run( + ["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True + ) subprocess.run(["git", "push"], cwd=work_dir, check=True) @@ -402,7 +434,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None: repo_url = f"https://huggingface.co/datasets/{repo_id}" env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files subprocess.run( - ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)], + [ + "git", + "clone", + "--branch", + branch, + "--single-branch", + "--depth", + "1", + repo_url, + str(work_dir), + ], check=True, env=env, ) @@ -410,13 +452,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None: def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]: lfs_tracked_files = subprocess.run( - ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True + ["git", "lfs", "ls-files", "-n"], + cwd=work_dir, + capture_output=True, + text=True, + check=True, ) lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines()) return [f for f in video_files if f not in lfs_tracked_files] -def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict: +def get_videos_info( + repo_id: str, local_dir: Path, video_keys: list[str], branch: str +) -> dict: # Assumes first episode video_files = [ DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) @@ -424,7 +472,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch ] hub_api = HfApi() hub_api.snapshot_download( - repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files + repo_id=repo_id, + repo_type="dataset", + local_dir=local_dir, + revision=branch, + allow_patterns=video_files, ) videos_info_dict = {} for vid_key, vid_path in zip(video_keys, video_files, strict=True): @@ -451,7 +503,11 @@ def convert_dataset( hub_api = HfApi() hub_api.snapshot_download( - repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/" + repo_id=repo_id, + repo_type="dataset", + revision=v1, + local_dir=v1x_dir, + ignore_patterns="videos*/", ) branch = "main" if test_branch: @@ -483,19 +539,31 @@ def convert_dataset( if single_task: tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices} dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} + tasks_by_episodes = { + ep_idx: [task] for ep_idx, task in tasks_by_episodes.items() + } elif tasks_path: tasks_by_episodes = load_json(tasks_path) - tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()} + tasks_by_episodes = { + int(ep_idx): task for ep_idx, task in tasks_by_episodes.items() + } dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} + tasks_by_episodes = { + ep_idx: [task] for ep_idx, task in tasks_by_episodes.items() + } elif tasks_col: - dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col) + dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col( + dataset, tasks_col + ) else: raise ValueError - assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} - tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] + assert set(tasks) == { + task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks + } + tasks = [ + {"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks) + ] write_jsonlines(tasks, v20_dir / TASKS_PATH) features["task_index"] = { "dtype": "int64", @@ -509,14 +577,25 @@ def convert_dataset( dataset = dataset.remove_columns(video_keys) clean_gitattr = Path( hub_api.hf_hub_download( - repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes" + repo_id=GITATTRIBUTES_REF, + repo_type="dataset", + local_dir=local_dir, + filename=".gitattributes", ) ).absolute() with tempfile.TemporaryDirectory() as tmp_video_dir: move_videos( - repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch + repo_id, + video_keys, + total_episodes, + total_chunks, + Path(tmp_video_dir), + clean_gitattr, + branch, ) - videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch) + videos_info = get_videos_info( + repo_id, v1x_dir, video_keys=video_keys, branch=branch + ) for key in video_keys: features[key]["shape"] = ( videos_info[key].pop("video.height"), @@ -524,15 +603,22 @@ def convert_dataset( videos_info[key].pop("video.channels"), ) features[key]["video_info"] = videos_info[key] - assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3) + assert math.isclose( + videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3 + ) if "encoding" in metadata_v1: - assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"] + assert ( + videos_info[key]["video.pix_fmt"] + == metadata_v1["encoding"]["pix_fmt"] + ) else: assert metadata_v1.get("video", 0) == 0 videos_info = None # Split data into 1 parquet file by episode - episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir) + episode_lengths = split_parquet_by_episodes( + dataset, total_episodes, total_chunks, v20_dir + ) if robot_config is not None: robot_type = robot_config.type @@ -543,7 +629,11 @@ def convert_dataset( # Episodes episodes = [ - {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]} + { + "episode_index": ep_idx, + "tasks": tasks_by_episodes[ep_idx], + "length": episode_lengths[ep_idx], + } for ep_idx in episode_indices ] write_jsonlines(episodes, v20_dir / EPISODES_PATH) @@ -566,16 +656,27 @@ def convert_dataset( } write_json(metadata_v2_0, v20_dir / INFO_PATH) convert_stats_to_json(v1x_dir, v20_dir) - card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs) + card = create_lerobot_dataset_card( + tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs + ) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) + hub_api.delete_folder( + repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch + ) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch) + hub_api.delete_folder( + repo_id=repo_id, + path_in_repo="meta_data", + repo_type="dataset", + revision=branch, + ) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch) + hub_api.delete_folder( + repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch + ) hub_api.upload_folder( repo_id=repo_id, diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index c38d570d..ee180d96 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -344,7 +344,9 @@ def get_audio_info(video_path: Path | str) -> dict: "json", str(video_path), ] - result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) if result.returncode != 0: raise RuntimeError(f"Error running ffprobe: {result.stderr}") @@ -358,7 +360,9 @@ def get_audio_info(video_path: Path | str) -> dict: "has_audio": True, "audio.channels": audio_stream_info.get("channels", None), "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, + "audio.bit_rate": int(audio_stream_info["bit_rate"]) + if audio_stream_info.get("bit_rate") + else None, "audio.sample_rate": int(audio_stream_info["sample_rate"]) if audio_stream_info.get("sample_rate") else None, @@ -380,7 +384,9 @@ def get_video_info(video_path: Path | str) -> dict: "json", str(video_path), ] - result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) if result.returncode != 0: raise RuntimeError(f"Error running ffprobe: {result.stderr}") diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 96ee7448..f84ef681 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -70,7 +70,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g return env -def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None: +def make_maniskill_env( + cfg: DictConfig, n_envs: int | None = None +) -> gym.vector.VectorEnv | None: """Make ManiSkill3 gym environment""" from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv @@ -87,7 +89,9 @@ def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector # state should have the size of 25 # env = ConvertToLeRobotEnv(env, n_envs) # env = PixelWrapper(cfg, env, n_envs) - env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) + env._max_episode_steps = env.max_episode_steps = ( + 50 # gym_utils.find_max_episode_steps_value(env) + ) env.unwrapped.metadata["render_fps"] = 20 return env @@ -114,7 +118,11 @@ class PixelWrapper(gym.Wrapper): def _get_obs(self, obs): frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2) self._frames.append(frame) - return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)} + return { + "pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to( + self.env.device + ) + } def reset(self, seed): obs, info = self.env.reset() # (seed=seed) @@ -148,7 +156,9 @@ class ConvertToLeRobotEnv(gym.Wrapper): images = torch.concat(images, axis=-1) # flatten the rest of the data which should just be state data - observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device) + observation = common.flatten_state_dict( + observation, use_torch=True, device=self.base_env.device + ) ret = dict() ret["state"] = observation ret["pixels"] = images diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 3a9cb2a5..b140270b 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -84,7 +84,9 @@ class Logger: pretrained_model_dir_name = "pretrained_model" training_state_file_name = "training_state.pth" - def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None): + def __init__( + self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None + ): """ Args: log_dir: The directory to save all logs and training outputs to. @@ -104,7 +106,9 @@ class Logger: enable_wandb = cfg.get("wandb", {}).get("enable", False) run_offline = not enable_wandb or not project if run_offline: - logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + logging.info( + colored("Logs will be saved locally.", "yellow", attrs=["bold"]) + ) self._wandb = None else: os.environ["WANDB_SILENT"] = "true" @@ -130,7 +134,9 @@ class Logger: # Handle custom step key for rl asynchronous training. self._wandb_custom_step_key: set[str] | None = None print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) - logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") + logging.info( + f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}" + ) self._wandb = wandb @classmethod @@ -151,7 +157,9 @@ class Logger: """ return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name - def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None): + def save_model( + self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None + ): """Save the weights of the Policy model using PyTorchModelHubMixin. The weights are saved in a folder called "pretrained_model" under the checkpoint directory. @@ -221,22 +229,30 @@ class Logger: else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}" ) self.save_model( - checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name + checkpoint_dir / self.pretrained_model_dir_name, + policy, + wandb_artifact_name=wandb_artifact_name, + ) + self.save_training_state( + checkpoint_dir, train_step, optimizer, scheduler, interaction_step ) - self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step) os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir) - def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int: + def load_last_training_state( + self, optimizer: Optimizer | dict, scheduler: LRScheduler | None + ) -> int: """ Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and random state, and return the global training step. """ - training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name) + training_state = torch.load( + self.last_checkpoint_dir / self.training_state_file_name + ) # For the case where the optimizer is a dictionary of optimizers (e.g., sac) if type(training_state["optimizer"]) is dict: - assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), ( - "Optimizer dictionaries do not have the same keys during resume!" - ) + assert set(training_state["optimizer"].keys()) == set( + optimizer.keys() + ), "Optimizer dictionaries do not have the same keys during resume!" for k, v in training_state["optimizer"].items(): optimizer[k].load_state_dict(v) else: @@ -248,10 +264,18 @@ class Logger: "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided." ) # Small hack to get the expected keys: use `get_global_random_state`. - set_global_random_state({k: training_state[k] for k in get_global_random_state()}) + set_global_random_state( + {k: training_state[k] for k in get_global_random_state()} + ) return training_state["step"] - def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None): + def log_dict( + self, + d, + step: int | None = None, + mode="train", + custom_step_key: str | None = None, + ): """Log a dictionary of metrics to WandB.""" assert mode in {"train", "eval"} # TODO(alexander-soare): Add local text log. @@ -280,12 +304,20 @@ class Logger: continue # Do not log the custom step key itself. - if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key: + if ( + self._wandb_custom_step_key is not None + and k in self._wandb_custom_step_key + ): continue if custom_step_key is not None: value_custom_step = d[custom_step_key] - self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}) + self._wandb.log( + { + f"{mode}/{k}": v, + f"{mode}/{custom_step_key}": value_custom_step, + } + ) continue self._wandb.log(data={f"{mode}/{k}": v}, step=step) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 72d4df03..3dec1584 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -74,7 +74,9 @@ class ACTPolicy(PreTrainedPolicy): self.model = ACT(config) if config.temporal_ensemble_coeff is not None: - self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) + self.temporal_ensembler = ACTTemporalEnsembler( + config.temporal_ensemble_coeff, config.chunk_size + ) self.reset() @@ -153,7 +155,8 @@ class ACTPolicy(PreTrainedPolicy): actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + F.l1_loss(batch["action"], actions_hat, reduction="none") + * ~batch["action_is_pad"].unsqueeze(-1) ).mean() loss_dict = {"l1_loss": l1_loss.item()} @@ -163,7 +166,12 @@ class ACTPolicy(PreTrainedPolicy): # KL-divergence per batch element, then take the mean over the batch. # (See App. B of https://arxiv.org/abs/1312.6114 for more details). mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ( + -0.5 + * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp()) + ) + .sum(-1) + .mean() ) loss_dict["kld_loss"] = mean_kld.item() loss = l1_loss + mean_kld * self.config.kl_weight @@ -217,7 +225,9 @@ class ACTTemporalEnsembler: ``` """ self.chunk_size = chunk_size - self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + self.ensemble_weights = torch.exp( + -temporal_ensemble_coeff * torch.arange(chunk_size) + ) self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) self.reset() @@ -233,7 +243,9 @@ class ACTTemporalEnsembler: time steps, and pop/return the next batch of actions in the sequence. """ self.ensemble_weights = self.ensemble_weights.to(device=actions.device) - self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to( + device=actions.device + ) if self.ensembled_actions is None: # Initializes `self._ensembled_action` to the sequence of actions predicted during the first # time step of the episode. @@ -241,19 +253,34 @@ class ACTTemporalEnsembler: # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor # operations later. self.ensembled_actions_count = torch.ones( - (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + (self.chunk_size, 1), + dtype=torch.long, + device=self.ensembled_actions.device, ) else: # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute # the online update for those entries. - self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] - self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] - self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] - self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) + self.ensembled_actions *= self.ensemble_weights_cumsum[ + self.ensembled_actions_count - 1 + ] + self.ensembled_actions += ( + actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + ) + self.ensembled_actions /= self.ensemble_weights_cumsum[ + self.ensembled_actions_count + ] + self.ensembled_actions_count = torch.clamp( + self.ensembled_actions_count + 1, max=self.chunk_size + ) # The last action, which has no prior online average, needs to get concatenated onto the end. - self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions = torch.cat( + [self.ensembled_actions, actions[:, -1:]], dim=1 + ) self.ensembled_actions_count = torch.cat( - [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] + [ + self.ensembled_actions_count, + torch.ones_like(self.ensembled_actions_count[-1:]), + ] ) # "Consume" the first action. action, self.ensembled_actions, self.ensembled_actions_count = ( @@ -319,7 +346,9 @@ class ACT(nn.Module): config.dim_model, ) # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) + self.vae_encoder_latent_output_proj = nn.Linear( + config.dim_model, config.latent_dim * 2 + ) # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch # dimension. num_input_token_encoder = 1 + config.chunk_size @@ -327,20 +356,28 @@ class ACT(nn.Module): num_input_token_encoder += 1 self.register_buffer( "vae_encoder_pos_enc", - create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), + create_sinusoidal_pos_embedding( + num_input_token_encoder, config.dim_model + ).unsqueeze(0), ) # Backbone for image feature extraction. if self.config.image_features: backbone_model = getattr(torchvision.models, config.vision_backbone)( - replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + replace_stride_with_dilation=[ + False, + False, + config.replace_final_stride_with_dilation, + ], weights=config.pretrained_backbone_weights, norm_layer=FrozenBatchNorm2d, ) # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final # feature map). # Note: The forward method of this returns a dict: {"feature_map": output}. - self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + self.backbone = IntermediateLayerGetter( + backbone_model, return_layers={"layer4": "feature_map"} + ) # Transformer (acts as VAE decoder when training with the variational objective). self.encoder = ACTEncoder(config) @@ -386,7 +423,9 @@ class ACT(nn.Module): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + def forward( + self, batch: dict[str, Tensor] + ) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: """A forward pass through the Action Chunking Transformer (with optional VAE encoder). `batch` should have the following structure: @@ -424,7 +463,9 @@ class ACT(nn.Module): if self.config.robot_state_feature: robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) + action_embed = self.vae_encoder_action_input_proj( + batch["action"] + ) # (B, S, D) if self.config.robot_state_feature: vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) @@ -465,20 +506,24 @@ class ACT(nn.Module): # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer - latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( - batch["observation.state"].device - ) + latent_sample = torch.zeros( + [batch_size, self.config.latent_dim], dtype=torch.float32 + ).to(batch["observation.state"].device) # Prepare transformer encoder inputs. encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] - encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) + encoder_in_pos_embed = list( + self.encoder_1d_feature_pos_embed.weight.unsqueeze(1) + ) # Robot state token. if self.config.robot_state_feature: encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) # Environment state token. if self.config.env_state_feature: encoder_in_tokens.append( - self.encoder_env_state_input_proj(batch["observation.environment_state"]) + self.encoder_env_state_input_proj( + batch["observation.environment_state"] + ) ) # Camera observation features and positional embeddings. @@ -535,12 +580,21 @@ class ACTEncoder(nn.Module): def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): super().__init__() self.is_vae_encoder = is_vae_encoder - num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers - self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) + num_layers = ( + config.n_vae_encoder_layers + if self.is_vae_encoder + else config.n_encoder_layers + ) + self.layers = nn.ModuleList( + [ACTEncoderLayer(config) for _ in range(num_layers)] + ) self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() def forward( - self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + self, + x: Tensor, + pos_embed: Tensor | None = None, + key_padding_mask: Tensor | None = None, ) -> Tensor: for layer in self.layers: x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) @@ -551,7 +605,9 @@ class ACTEncoder(nn.Module): class ACTEncoderLayer(nn.Module): def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.self_attn = nn.MultiheadAttention( + config.dim_model, config.n_heads, dropout=config.dropout + ) # Feed forward layers. self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) @@ -566,7 +622,9 @@ class ACTEncoderLayer(nn.Module): self.activation = get_activation_fn(config.feedforward_activation) self.pre_norm = config.pre_norm - def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: + def forward( + self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + ) -> Tensor: skip = x if self.pre_norm: x = self.norm1(x) @@ -591,7 +649,9 @@ class ACTDecoder(nn.Module): def __init__(self, config: ACTConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() - self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) + self.layers = nn.ModuleList( + [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)] + ) self.norm = nn.LayerNorm(config.dim_model) def forward( @@ -603,7 +663,10 @@ class ACTDecoder(nn.Module): ) -> Tensor: for layer in self.layers: x = layer( - x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + x, + encoder_out, + decoder_pos_embed=decoder_pos_embed, + encoder_pos_embed=encoder_pos_embed, ) if self.norm is not None: x = self.norm(x) @@ -613,8 +676,12 @@ class ACTDecoder(nn.Module): class ACTDecoderLayer(nn.Module): def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) - self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.self_attn = nn.MultiheadAttention( + config.dim_model, config.n_heads, dropout=config.dropout + ) + self.multihead_attn = nn.MultiheadAttention( + config.dim_model, config.n_heads, dropout=config.dropout + ) # Feed forward layers. self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) @@ -655,7 +722,9 @@ class ACTDecoderLayer(nn.Module): if self.pre_norm: x = self.norm1(x) q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) - x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = self.self_attn(q, k, value=x)[ + 0 + ] # select just the output, not the attention weights x = skip + self.dropout1(x) if self.pre_norm: skip = x @@ -692,9 +761,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso """ def get_position_angle_vec(position): - return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] + return [ + position / np.power(10000, 2 * (hid_j // 2) / dimension) + for hid_j in range(dimension) + ] - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(num_positions)] + ) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 return torch.from_numpy(sinusoid_table).float() @@ -739,7 +813,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module): x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi inverse_frequency = self._temperature ** ( - 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + 2 + * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) + / self.dimension ) x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) @@ -747,9 +823,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module): # Note: this stack then flatten operation results in interleaved sine and cosine terms. # pos_embed_x and pos_embed_y are (1, H, W, C // 2). - pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) - pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) - pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) + pos_embed_x = torch.stack( + (x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1 + ).flatten(3) + pos_embed_y = torch.stack( + (y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1 + ).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute( + 0, 3, 1, 2 + ) # (1, C, H, W) return pos_embed diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 9ecadcb0..d331dddf 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -132,7 +132,11 @@ class DiffusionPolicy(PreTrainedPolicy): if len(self._queues["action"]) == 0: # stack n latest observations from the queue - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + batch = { + k: torch.stack(list(self._queues[k]), dim=1) + for k in batch + if k in self._queues + } actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? @@ -189,7 +193,9 @@ class DiffusionModel(nn.Module): if self.config.env_state_feature: global_cond_dim += self.config.env_state_feature.shape[0] - self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) + self.unet = DiffusionConditionalUnet1d( + config, global_cond_dim=global_cond_dim * config.n_obs_steps + ) self.noise_scheduler = _make_noise_scheduler( config.noise_scheduler_type, @@ -209,7 +215,10 @@ class DiffusionModel(nn.Module): # ========= inference ============ def conditional_sample( - self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None + self, + batch_size: int, + global_cond: Tensor | None = None, + generator: torch.Generator | None = None, ) -> Tensor: device = get_device_from_parameters(self) dtype = get_dtype_from_parameters(self) @@ -232,7 +241,9 @@ class DiffusionModel(nn.Module): global_cond=global_cond, ) # Compute previous image: x_t -> x_t-1 - sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample + sample = self.noise_scheduler.step( + model_output, t, sample, generator=generator + ).prev_sample return sample @@ -244,27 +255,39 @@ class DiffusionModel(nn.Module): if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: # Combine batch and sequence dims while rearranging to make the camera index dimension first. - images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") + images_per_camera = einops.rearrange( + batch["observation.images"], "b s n ... -> n (b s) ..." + ) img_features_list = torch.cat( [ encoder(images) - for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True) + for encoder, images in zip( + self.rgb_encoder, images_per_camera, strict=True + ) ] ) # Separate batch and sequence dims back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). img_features = einops.rearrange( - img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + img_features_list, + "(n b s) ... -> b s (n ...)", + b=batch_size, + s=n_obs_steps, ) else: # Combine batch, sequence, and "which camera" dims before passing to shared encoder. img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + einops.rearrange( + batch["observation.images"], "b s n ... -> (b s n) ..." + ) ) # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). img_features = einops.rearrange( - img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + img_features, + "(b s n) ... -> b s (n ...)", + b=batch_size, + s=n_obs_steps, ) global_cond_feats.append(img_features) @@ -350,7 +373,9 @@ class DiffusionModel(nn.Module): elif self.config.prediction_type == "sample": target = batch["action"] else: - raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") + raise ValueError( + f"Unsupported prediction type {self.config.prediction_type}" + ) loss = F.mse_loss(pred, target, reduction="none") @@ -410,7 +435,9 @@ class SpatialSoftmax(nn.Module): # we could use torch.linspace directly but that seems to behave slightly differently than numpy # and causes a small degradation in pc_success of pre-trained models. - pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x, pos_y = np.meshgrid( + np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) + ) pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() # register as buffer so it's moved to the correct device. @@ -452,7 +479,9 @@ class DiffusionRgbEncoder(nn.Module): # Always use center crop for eval self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) if config.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + self.maybe_random_crop = torchvision.transforms.RandomCrop( + config.crop_shape + ) else: self.maybe_random_crop = self.center_crop else: @@ -473,7 +502,9 @@ class DiffusionRgbEncoder(nn.Module): self.backbone = _replace_submodules( root_module=self.backbone, predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + func=lambda x: nn.GroupNorm( + num_groups=x.num_features // 16, num_channels=x.num_features + ), ) # Set up pooling and final layers. @@ -515,7 +546,9 @@ class DiffusionRgbEncoder(nn.Module): def _replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] + root_module: nn.Module, + predicate: Callable[[nn.Module], bool], + func: Callable[[nn.Module], nn.Module], ) -> nn.Module: """ Args: @@ -528,7 +561,11 @@ def _replace_submodules( if predicate(root_module): return func(root_module) - replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + replace_list = [ + k.split(".") + for k, m in root_module.named_modules(remove_duplicate=True) + if predicate(m) + ] for *parents, k in replace_list: parent_module = root_module if len(parents) > 0: @@ -543,7 +580,9 @@ def _replace_submodules( else: setattr(parent_module, k, tgt_module) # verify that all BN are replaced - assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + assert not any( + predicate(m) for _, m in root_module.named_modules(remove_duplicate=True) + ) return root_module @@ -571,7 +610,9 @@ class DiffusionConv1dBlock(nn.Module): super().__init__() self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.Conv1d( + inp_channels, out_channels, kernel_size, padding=kernel_size // 2 + ), nn.GroupNorm(n_groups, out_channels), nn.Mish(), ) @@ -594,9 +635,13 @@ class DiffusionConditionalUnet1d(nn.Module): # Encoder for the diffusion timestep. self.diffusion_step_encoder = nn.Sequential( DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), - nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), + nn.Linear( + config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4 + ), nn.Mish(), - nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), + nn.Linear( + config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim + ), ) # The FiLM conditioning dimension. @@ -621,10 +666,16 @@ class DiffusionConditionalUnet1d(nn.Module): self.down_modules.append( nn.ModuleList( [ - DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs), - DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d( + dim_in, dim_out, **common_res_block_kwargs + ), + DiffusionConditionalResidualBlock1d( + dim_out, dim_out, **common_res_block_kwargs + ), # Downsample as long as it is not the last block. - nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), + nn.Conv1d(dim_out, dim_out, 3, 2, 1) + if not is_last + else nn.Identity(), ] ) ) @@ -633,10 +684,14 @@ class DiffusionConditionalUnet1d(nn.Module): self.mid_modules = nn.ModuleList( [ DiffusionConditionalResidualBlock1d( - config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + config.down_dims[-1], + config.down_dims[-1], + **common_res_block_kwargs, ), DiffusionConditionalResidualBlock1d( - config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + config.down_dims[-1], + config.down_dims[-1], + **common_res_block_kwargs, ), ] ) @@ -649,10 +704,16 @@ class DiffusionConditionalUnet1d(nn.Module): nn.ModuleList( [ # dim_in * 2, because it takes the encoder's skip connection as well - DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs), - DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d( + dim_in * 2, dim_out, **common_res_block_kwargs + ), + DiffusionConditionalResidualBlock1d( + dim_out, dim_out, **common_res_block_kwargs + ), # Upsample as long as it is not the last block. - nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) + if not is_last + else nn.Identity(), ] ) ) @@ -726,17 +787,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module): self.use_film_scale_modulation = use_film_scale_modulation self.out_channels = out_channels - self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + self.conv1 = DiffusionConv1dBlock( + in_channels, out_channels, kernel_size, n_groups=n_groups + ) # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) - self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + self.conv2 = DiffusionConv1dBlock( + out_channels, out_channels, kernel_size, n_groups=n_groups + ) # A final convolution for dimension matching the residual (if needed). self.residual_conv = ( - nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + nn.Conv1d(in_channels, out_channels, 1) + if in_channels != out_channels + else nn.Identity() ) def forward(self, x: Tensor, cond: Tensor) -> Tensor: diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index e6700547..eb023f9f 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -7,7 +7,9 @@ from torch import Tensor, nn from .configuration_classifier import ClassifierConfig -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) @@ -15,7 +17,10 @@ class ClassifierOutput: """Wrapper for classifier outputs with additional metadata.""" def __init__( - self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None + self, + logits: Tensor, + probabilities: Optional[Tensor] = None, + hidden_states: Optional[Tensor] = None, ): self.logits = logits self.probabilities = probabilities @@ -43,12 +48,14 @@ class Classifier( name = "classifier" def __init__(self, config: ClassifierConfig): - from transformers import AutoImageProcessor, AutoModel + from transformers import AutoModel super().__init__() self.config = config # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) - encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) + encoder = AutoModel.from_pretrained( + self.config.model_name, trust_remote_code=True + ) # Extract vision model if we're given a multimodal model if hasattr(encoder, "vision_model"): logging.info("Multimodal model detected - using vision encoder only") @@ -74,7 +81,9 @@ class Classifier( self.feature_dim = self.encoder.fc.in_features self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) elif hasattr(self.encoder.config, "hidden_sizes"): - self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension + self.feature_dim = self.encoder.config.hidden_sizes[ + -1 + ] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") @@ -94,14 +103,19 @@ class Classifier( if hasattr(self.encoder.config, "hidden_size"): input_dim = self.encoder.config.hidden_size else: - raise ValueError("Unsupported transformer architecture since hidden_size is not found") + raise ValueError( + "Unsupported transformer architecture since hidden_size is not found" + ) self.classifier_head = nn.Sequential( nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), nn.Dropout(self.config.dropout_rate), nn.LayerNorm(self.config.hidden_dim), nn.ReLU(), - nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes), + nn.Linear( + self.config.hidden_dim, + 1 if self.config.num_classes == 2 else self.config.num_classes, + ), ) self.classifier_head = self.classifier_head.to(self.config.device) @@ -127,7 +141,10 @@ class Classifier( return features else: # Transformer models outputs = self.encoder(processed) - if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: + if ( + hasattr(outputs, "pooler_output") + and outputs.pooler_output is not None + ): return outputs.pooler_output return outputs.last_hidden_state[:, 0, :] @@ -143,7 +160,9 @@ class Classifier( else: probabilities = torch.softmax(logits, dim=-1) - return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) + return ClassifierOutput( + logits=logits, probabilities=probabilities, hidden_states=encoder_outputs + ) def predict_reward(self, x, threshold=0.6): if self.config.num_classes == 2: diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index db596982..9eb864ec 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -59,7 +59,9 @@ class SACPolicy( config.input_normalization_params ) self.normalize_inputs = Normalize( - config.input_shapes, config.input_normalization_modes, input_normalization_params + config.input_shapes, + config.input_normalization_modes, + input_normalization_params, ) else: self.normalize_inputs = nn.Identity() @@ -90,7 +92,8 @@ class SACPolicy( ensemble=Ensemble( [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], + input_dim=encoder_critic.output_dim + + config.output_shapes["action"][0], **config.critic_network_kwargs, ) for _ in range(config.num_critics) @@ -104,7 +107,8 @@ class SACPolicy( ensemble=Ensemble( [ CriticHead( - input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], + input_dim=encoder_critic.output_dim + + config.output_shapes["action"][0], **config.critic_network_kwargs, ) for _ in range(config.num_critics) @@ -120,13 +124,17 @@ class SACPolicy( self.actor = Policy( encoder=encoder_actor, - network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), + network=MLP( + input_dim=encoder_actor.output_dim, **config.actor_network_kwargs + ), action_dim=config.output_shapes["action"][0], encoder_is_shared=config.shared_encoder, **config.policy_kwargs, ) if config.target_entropy is None: - config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) + config.target_entropy = ( + -np.prod(config.output_shapes["action"][0]) / 2 + ) # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise @@ -153,7 +161,11 @@ class SACPolicy( return actions def critic_forward( - self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None + self, + observations: dict[str, Tensor], + actions: Tensor, + use_target: bool = False, + observation_features: Tensor | None = None, ) -> Tensor: """Forward pass through a critic network ensemble @@ -173,21 +185,37 @@ class SACPolicy( def update_target_networks(self): """Update target networks with exponential moving average""" for target_param, param in zip( - self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False + self.critic_target.parameters(), + self.critic_ensemble.parameters(), + strict=False, ): target_param.data.copy_( param.data * self.config.critic_target_update_weight + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor: + def compute_loss_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features: Tensor | None = None, + next_observation_features: Tensor | None = None, + ) -> Tensor: temperature = self.log_alpha.exp().item() with torch.no_grad(): - next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) + next_action_preds, next_log_probs, _ = self.actor( + next_observations, next_observation_features + ) # 2- compute q targets q_targets = self.critic_forward( - observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features + observations=next_observations, + actions=next_action_preds, + use_target=True, + observation_features=next_observation_features, ) # subsample critics to prevent overfitting if use high UTD (update to date) @@ -204,7 +232,12 @@ class SACPolicy( td_target = rewards + (1 - done) * self.config.discount * min_q # 3- compute predicted qs - q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features) + q_preds = self.critic_forward( + observations, + actions, + use_target=False, + observation_features=observation_features, + ) # 4- Calculate loss # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. @@ -219,20 +252,31 @@ class SACPolicy( ).sum() return critics_loss - def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: + def compute_loss_temperature( + self, observations, observation_features: Tensor | None = None + ) -> Tensor: """Compute the temperature loss""" # calculate temperature loss with torch.no_grad(): _, log_probs, _ = self.actor(observations, observation_features) - temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean() + temperature_loss = ( + -self.log_alpha.exp() * (log_probs + self.config.target_entropy) + ).mean() return temperature_loss - def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor: + def compute_loss_actor( + self, observations, observation_features: Tensor | None = None + ) -> Tensor: temperature = self.log_alpha.exp().item() actions_pi, log_probs, _ = self.actor(observations, observation_features) - q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features) + q_preds = self.critic_forward( + observations, + actions_pi, + use_target=False, + observation_features=observation_features, + ) min_q_preds = q_preds.min(dim=0)[0] actor_loss = ((temperature * log_probs) - min_q_preds).mean() @@ -259,7 +303,11 @@ class MLP(nn.Module): if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[0])) - layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) + layers.append( + activations + if isinstance(activations, nn.Module) + else getattr(nn, activations)() + ) # Rest of the layers for i in range(1, len(hidden_dims)): @@ -270,7 +318,9 @@ class MLP(nn.Module): layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[i])) layers.append( - activations if isinstance(activations, nn.Module) else getattr(nn, activations)() + activations + if isinstance(activations, nn.Module) + else getattr(nn, activations)() ) self.net = nn.Sequential(*layers) @@ -381,7 +431,11 @@ class CriticEnsemble(nn.Module): actions = self.output_normalization(actions)["action"] actions = actions.to(device) - obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations)) + obs_enc = ( + observation_features + if observation_features is not None + else (observations if self.encoder is None else self.encoder(observations)) + ) inputs = torch.cat([obs_enc, actions], dim=-1) q_values = self.ensemble(inputs) # [num_critics, B, 1] @@ -445,7 +499,11 @@ class Policy(nn.Module): observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Encode observations if encoder exists - obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations)) + obs_enc = ( + observation_features + if observation_features is not None + else (observations if self.encoder is None else self.encoder(observations)) + ) # Get network outputs outputs = self.network(obs_enc) @@ -454,11 +512,15 @@ class Policy(nn.Module): # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!" + assert not torch.isnan( + log_std + ).any(), "[ERROR] log_std became NaN after std_layer!" if self.use_tanh_squash: log_std = torch.tanh(log_std) - log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0) + log_std = self.log_std_min + 0.5 * ( + self.log_std_max - self.log_std_min + ) * (log_std + 1.0) else: log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) else: @@ -471,7 +533,9 @@ class Policy(nn.Module): if self.use_tanh_squash: actions = torch.tanh(x_t) - log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh + log_probs -= torch.log( + (1 - actions.pow(2)) + 1e-6 + ) # Adjust log-probs for Tanh else: actions = x_t # No Tanh; raw Gaussian sample @@ -518,12 +582,15 @@ class SACObservationEncoder(nn.Module): freeze_image_encoder(self.image_enc_layers) else: self.parameters_to_optimize += list(self.image_enc_layers.parameters()) - self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + self.all_image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] if "observation.state" in config.input_shapes: self.state_enc_layers = nn.Sequential( nn.Linear( - in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim + in_features=config.input_shapes["observation.state"][0], + out_features=config.latent_dim, ), nn.LayerNorm(normalized_shape=config.latent_dim), nn.Tanh(), @@ -544,7 +611,9 @@ class SACObservationEncoder(nn.Module): self.aggregation_size += config.latent_dim self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) - self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) + self.aggregation_layer = nn.Linear( + in_features=self.aggregation_size, out_features=config.latent_dim + ) self.parameters_to_optimize += list(self.aggregation_layer.parameters()) def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: @@ -557,13 +626,19 @@ class SACObservationEncoder(nn.Module): obs_dict = self.input_normalization(obs_dict) # Batch all images along the batch dimension, then encode them. if len(self.all_image_keys) > 0: - images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0) + images_batched = torch.cat( + [obs_dict[key] for key in self.all_image_keys], dim=0 + ) images_batched = self.image_enc_layers(images_batched) - embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) + embeddings_chunks = torch.chunk( + images_batched, dim=0, chunks=len(self.all_image_keys) + ) feat.extend(embeddings_chunks) if "observation.environment_state" in self.config.input_shapes: - feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + feat.append( + self.env_state_enc_layers(obs_dict["observation.environment_state"]) + ) if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) @@ -631,7 +706,9 @@ class PretrainedImageEncoder(nn.Module): def __init__(self, config): super().__init__() - self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) + self.image_enc_layers, self.image_enc_out_shape = ( + self._load_pretrained_vision_encoder(config) + ) self.image_enc_proj = nn.Sequential( nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), nn.LayerNorm(config.latent_dim), @@ -642,15 +719,21 @@ class PretrainedImageEncoder(nn.Module): """Set up CNN encoder""" from transformers import AutoModel - self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True) + self.image_enc_layers = AutoModel.from_pretrained( + config.vision_encoder_name, trust_remote_code=True + ) # self.image_enc_layers.pooler = Identity() if hasattr(self.image_enc_layers.config, "hidden_sizes"): - self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension + self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[ + -1 + ] # Last channel dimension elif hasattr(self.image_enc_layers, "fc"): self.image_enc_out_shape = self.image_enc_layers.fc.in_features else: - raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN") + raise ValueError( + "Unsupported vision encoder architecture, make sure you are using a CNN" + ) return self.image_enc_layers, self.image_enc_out_shape def forward(self, x): @@ -673,7 +756,7 @@ def orthogonal_init(): class Identity(nn.Module): def __init__(self): - super(Identity, self).__init__() + super().__init__() def forward(self, x): return x @@ -701,7 +784,9 @@ class Ensemble(nn.Module): return self.module(*args, **kwargs) def forward(self, *args, **kwargs): - return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs) + return torch.vmap(self._call, (0, None), randomness="different")( + self.params, *args, **kwargs + ) def __repr__(self): return f"Vectorized {len(self)}x " + self._repr @@ -710,7 +795,9 @@ class Ensemble(nn.Module): # TODO (azouitine): I think in our case this function is not usefull we should remove it # after some investigation # borrowed from tdmpc -def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: +def flatten_forward_unflatten( + fn: Callable[[Tensor], Tensor], image_tensor: Tensor +) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. Args: @@ -736,7 +823,9 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: for key, value in inner_dict.items(): converted_params[outer_key][key] = torch.tensor(value) if "image" in outer_key: - converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1) + converted_params[outer_key][key] = converted_params[outer_key][ + key + ].view(3, 1, 1) return converted_params diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 3fce01df..da1edfee 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -183,7 +183,9 @@ class TDMPCConfig(PreTrainedConfig): "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1." ) if not self.use_mpc: - raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.") + raise ValueError( + "If `n_action_steps > 1`, `use_mpc` must be set to `True`." + ) if self.n_action_steps > self.horizon: raise ValueError("`n_action_steps` must be less than or equal to `horizon`.") diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 0940f198..10e8bbcc 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -100,7 +100,9 @@ class TDMPCPolicy(PreTrainedPolicy): """ self._queues = { "observation.state": deque(maxlen=1), - "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), + "action": deque( + maxlen=max(self.config.n_action_steps, self.config.n_action_repeats) + ), } if self.config.image_features: self._queues["observation.image"] = deque(maxlen=1) @@ -189,7 +191,11 @@ class TDMPCPolicy(PreTrainedPolicy): # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled # trajectories. - z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) + z = einops.repeat( + z, + "b d -> n b d", + n=self.config.n_gaussian_samples + self.config.n_pi_samples, + ) # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization # algorithm. @@ -211,35 +217,47 @@ class TDMPCPolicy(PreTrainedPolicy): self.config.action_feature.shape[0], device=std.device, ) - gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) + gaussian_actions = torch.clamp( + mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1 + ) # Compute elite actions. actions = torch.cat([gaussian_actions, pi_actions], dim=1) value = self.estimate_value(z, actions).nan_to_num_(0) - elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch) + elite_idxs = torch.topk( + value, self.config.n_elites, dim=0 + ).indices # (n_elites, batch) elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch) # (horizon, n_elites, batch, action_dim) - elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1) + elite_actions = actions.take_along_dim( + einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1 + ) # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites. max_value = elite_value.max(0, keepdim=True)[0] # (1, batch) # The weighting is a softmax over trajectory values. Note that this is not the same as the usage # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). - score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) + score = torch.exp( + self.config.elite_weighting_temperature * (elite_value - max_value) + ) score /= score.sum(axis=0, keepdim=True) # (horizon, batch, action_dim) - _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) + _mean = torch.sum( + einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1 + ) _std = torch.sqrt( torch.sum( einops.rearrange(score, "n b -> n b 1") - * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2, + * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) + ** 2, dim=1, ) ) # Update mean with an exponential moving average, and std with a direct replacement. mean = ( - self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean + self.config.gaussian_mean_momentum * mean + + (1 - self.config.gaussian_mean_momentum) * _mean ) std = _std.clamp_(self.config.min_std, self.config.max_std) @@ -248,7 +266,9 @@ class TDMPCPolicy(PreTrainedPolicy): # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax # scores from the last iteration. - actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)] + actions = elite_actions[ + :, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size) + ] return actions @@ -271,7 +291,8 @@ class TDMPCPolicy(PreTrainedPolicy): # of the FOWM paper. if self.config.uncertainty_regularizer_coeff > 0: regularization = -( - self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0) + self.config.uncertainty_regularizer_coeff + * self.model.Qs(z, actions[t]).std(0) ) else: regularization = 0 @@ -291,15 +312,22 @@ class TDMPCPolicy(PreTrainedPolicy): if self.config.q_ensemble_size > 2: G += ( running_discount - * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[ - 0 - ] + * torch.min( + terminal_values[ + torch.randint(0, self.config.q_ensemble_size, size=(2,)) + ], + dim=0, + )[0] ) else: G += running_discount * torch.min(terminal_values, dim=0)[0] # Finally, also regularize the terminal value. if self.config.uncertainty_regularizer_coeff > 0: - G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) + G -= ( + running_discount + * self.config.uncertainty_regularizer_coeff + * terminal_values.std(0) + ) return G def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: @@ -329,7 +357,10 @@ class TDMPCPolicy(PreTrainedPolicy): # Apply random image augmentations. if self.config.image_features and self.config.max_random_shift_ratio > 0: observations["observation.image"] = flatten_forward_unflatten( - partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), + partial( + random_shifts_aug, + max_random_shift_ratio=self.config.max_random_shift_ratio, + ), observations["observation.image"], ) @@ -347,14 +378,20 @@ class TDMPCPolicy(PreTrainedPolicy): # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action # gives us a next `z`. batch_size = batch["index"].shape[0] - z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) + z_preds = torch.empty( + horizon + 1, batch_size, self.config.latent_dim, device=device + ) z_preds[0] = self.model.encode(current_observation) reward_preds = torch.empty_like(reward, device=device) for t in range(horizon): - z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t]) + z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward( + z_preds[t], action[t] + ) # Compute Q and V value predictions based on the latent rollout. - q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch) + q_preds_ensemble = self.model.Qs( + z_preds[:-1], action + ) # (ensemble, horizon, batch) v_preds = self.model.V(z_preds[:-1]) info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()}) @@ -368,10 +405,14 @@ class TDMPCPolicy(PreTrainedPolicy): # actions (not actions estimated by π). # Note: Here we do not use self.model_target, but self.model. This is to follow the original code # and the FOWM paper. - q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations)) + q_targets = reward + self.config.discount * self.model.V( + self.model.encode(next_observations) + ) # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we # are using them to compute loss for V. - v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True) + v_targets = self.model_target.Qs( + z_preds[:-1].detach(), action, return_min=True + ) # Compute losses. # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the @@ -414,7 +455,9 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * F.mse_loss( q_preds_ensemble, - einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), + einops.repeat( + q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0] + ), reduction="none", ).sum(0) # sum over ensemble # `q_preds_ensemble` depends on the first observation and the actions. @@ -452,12 +495,14 @@ class TDMPCPolicy(PreTrainedPolicy): z_preds = z_preds.detach() # Use stopgrad for the advantage calculation. with torch.no_grad(): - advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V( - z_preds[:-1] - ) + advantage = self.model_target.Qs( + z_preds[:-1], action, return_min=True + ) - self.model.V(z_preds[:-1]) info["advantage"] = advantage[0] # (t, b) - exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0) + exp_advantage = torch.clamp( + torch.exp(advantage * self.config.advantage_scaling), max=100.0 + ) action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) # Calculate the MSE between the actions and the action predictions. # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation @@ -511,7 +556,9 @@ class TDMPCPolicy(PreTrainedPolicy): # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995) - update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) + update_ema_parameters( + self.model_target, self.model, self.config.target_model_momentum + ) class TDMPCTOLD(nn.Module): @@ -598,7 +645,9 @@ class TDMPCTOLD(nn.Module): "Sanity check. The last linear layer needs 0 initialization on weights." ) nn.init.zeros_(m[-1].weight) - nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure + nn.init.zeros_( + m[-1].bias + ) # this has already been done, but keep this line here for good measure def encode(self, obs: dict[str, Tensor]) -> Tensor: """Encodes an observation into its latent representation.""" @@ -702,11 +751,26 @@ class TDMPCObservationEncoder(nn.Module): stride=2, ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.Conv2d( + config.image_encoder_hidden_dim, + config.image_encoder_hidden_dim, + 5, + stride=2, + ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.Conv2d( + config.image_encoder_hidden_dim, + config.image_encoder_hidden_dim, + 3, + stride=2, + ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.Conv2d( + config.image_encoder_hidden_dim, + config.image_encoder_hidden_dim, + 3, + stride=2, + ), nn.ReLU(), ) dummy_shape = (1, *next(iter(config.image_features.values())).shape) @@ -796,12 +860,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param.""" for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True): for (n_p_ema, p_ema), (n_p, p) in zip( - ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True + ema_module.named_parameters(recurse=False), + module.named_parameters(recurse=False), + strict=True, ): assert n_p_ema == n_p, "Parameter names don't match for EMA model update" if isinstance(p, dict): raise RuntimeError("Dict parameter not supported") - if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad: + if ( + isinstance(module, nn.modules.batchnorm._BatchNorm) + or not p.requires_grad + ): # Copy BatchNorm parameters, and non-trainable parameters directly. p_ema.copy_(p.to(dtype=p_ema.dtype).data) with torch.no_grad(): @@ -809,7 +878,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha) -def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: +def flatten_forward_unflatten( + fn: Callable[[Tensor], Tensor], image_tensor: Tensor +) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. Args: diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 97a08e2f..201870dd 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -145,8 +145,14 @@ class VQBeTPolicy(PreTrainedPolicy): ) if len(self._queues["action"]) == 0: - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] + batch = { + k: torch.stack(list(self._queues[k]), dim=1) + for k in batch + if k in self._queues + } + actions = self.vqbet(batch, rollout=True)[ + :, : self.config.action_chunk_size + ] # the dimension of returned action is (batch_size, action_chunk_size, action_dim) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -168,7 +174,9 @@ class VQBeTPolicy(PreTrainedPolicy): # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree). loss, n_different_codes, n_different_combinations, recon_l1_error = ( - self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) + self.vqbet.action_head.discretize( + self.config.n_vqvae_training_steps, batch["action"] + ) ) return loss, { "n_different_codes": n_different_codes, @@ -225,7 +233,9 @@ class SpatialSoftmax(nn.Module): # we could use torch.linspace directly but that seems to behave slightly differently than numpy # and causes a small degradation in pc_success of pre-trained models. - pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x, pos_y = np.meshgrid( + np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) + ) pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() # register as buffer so it's moved to the correct device. @@ -339,7 +349,12 @@ class VQBeTModel(nn.Module): num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1 self.register_buffer( "select_target_actions_indices", - torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]), + torch.row_stack( + [ + torch.arange(i, i + self.config.action_chunk_size) + for i in range(num_tokens) + ] + ), ) def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: @@ -354,7 +369,11 @@ class VQBeTModel(nn.Module): ) # Separate batch and sequence dims. img_features = einops.rearrange( - img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images + img_features, + "(b s n) ... -> b s n ...", + b=batch_size, + s=n_obs_steps, + n=self.num_images, ) # Arrange prior and current observation step tokens as shown in the class docstring. @@ -366,13 +385,19 @@ class VQBeTModel(nn.Module): input_tokens.append( self.state_projector(batch["observation.state"]) ) # (batch, obs_step, projection dims) - input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) + input_tokens.append( + einops.repeat( + self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps + ) + ) # Interleave tokens by stacking and rearranging. input_tokens = torch.stack(input_tokens, dim=2) input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d") len_additional_action_token = self.config.n_action_pred_token - 1 - future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1) + future_action_tokens = self.action_token.repeat( + batch_size, len_additional_action_token, 1 + ) # add additional action query tokens for predicting future action chunks input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1) @@ -391,7 +416,11 @@ class VQBeTModel(nn.Module): # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional). if len_additional_action_token > 0: features = torch.cat( - [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1 + [ + features[:, historical_act_pred_index], + features[:, -len_additional_action_token:], + ], + dim=1, ) else: features = features[:, historical_act_pred_index] @@ -399,13 +428,15 @@ class VQBeTModel(nn.Module): action_head_output = self.action_head(features) # if rollout, VQ-BeT don't calculate loss if rollout: - return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape( - batch_size, self.config.action_chunk_size, -1 - ) + return action_head_output["predicted_action"][ + :, n_obs_steps - 1, : + ].reshape(batch_size, self.config.action_chunk_size, -1) # else, it calculate overall loss (bin prediction loss, and offset loss) else: output = batch["action"][:, self.select_target_actions_indices] - loss = self.action_head.loss_fn(action_head_output, output, reduction="mean") + loss = self.action_head.loss_fn( + action_head_output, output, reduction="mean" + ) return action_head_output, loss @@ -440,7 +471,9 @@ class VQBeTHead(nn.Module): else: self.map_to_cbet_preds_bin = MLP( in_channels=config.gpt_output_dim, - hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed], + hidden_channels=[ + self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed + ], ) self.map_to_cbet_preds_offset = MLP( in_channels=config.gpt_output_dim, @@ -467,7 +500,10 @@ class VQBeTHead(nn.Module): loss, metric = self.vqvae_model.vqvae_forward(actions) n_different_codes = sum( - [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)] + [ + len(torch.unique(metric[2][:, i])) + for i in range(self.vqvae_model.vqvae_num_layers) + ] ) n_different_combinations = len(torch.unique(metric[2], dim=0)) recon_l1_error = metric[0].detach().cpu().item() @@ -514,7 +550,13 @@ class VQBeTHead(nn.Module): cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin( torch.cat( - (x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)), + ( + x, + F.one_hot( + sampled_primary_centers, + num_classes=self.config.vqvae_n_embed, + ), + ), axis=1, ) ) @@ -522,19 +564,29 @@ class VQBeTHead(nn.Module): cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1 ) sampled_secondary_centers = einops.rearrange( - torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1), + torch.multinomial( + cbet_secondary_probs.view(-1, choices), num_samples=1 + ), "(NT) 1 -> NT", NT=NT, ) - sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1) - cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1) + sampled_centers = torch.stack( + (sampled_primary_centers, sampled_secondary_centers), axis=1 + ) + cbet_logits = torch.stack( + [cbet_primary_logits, cbet_secondary_logits], dim=1 + ) # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once. else: cbet_logits = self.map_to_cbet_preds_bin(x) cbet_logits = einops.rearrange( - cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers + cbet_logits, + "(NT) (G C) -> (NT) G C", + G=self.vqvae_model.vqvae_num_layers, + ) + cbet_probs = torch.softmax( + cbet_logits / self.config.bet_softmax_temperature, dim=-1 ) - cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) NT, G, choices = cbet_probs.shape sampled_centers = einops.rearrange( torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), @@ -554,9 +606,17 @@ class VQBeTHead(nn.Module): sampled_offsets = sampled_offsets.sum(dim=1) with torch.no_grad(): # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder - return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach() + return_decoder_input = ( + self.vqvae_model.get_embeddings_from_code(sampled_centers) + .clone() + .detach() + ) # pass the centroids through decoder to get actions. - decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach() + decoded_action = ( + self.vqvae_model.get_action_from_latent(return_decoder_input) + .clone() + .detach() + ) # reshaped extracted offset to match with decoded centroids sampled_offsets = einops.rearrange( sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size @@ -605,7 +665,9 @@ class VQBeTHead(nn.Module): # Figure out the loss for the actions. # First, we need to find the closest cluster center for each ground truth action. with torch.no_grad(): - state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G + state_vq, action_bins = self.vqvae_model.get_code( + action_seq + ) # action_bins: NT, G # Now we can compute the loss. @@ -628,8 +690,12 @@ class VQBeTHead(nn.Module): + cbet_loss2 * self.config.secondary_code_loss_weight ) - equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT) - equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT) + equal_primary_code_rate = torch.sum( + (action_bins[:, 0] == sampled_centers[:, 0]).int() + ) / (NT) + equal_secondary_code_rate = torch.sum( + (action_bins[:, 1] == sampled_centers[:, 1]).int() + ) / (NT) action_mse_error = torch.mean((action_seq - predicted_action) ** 2) vq_action_error = torch.mean(torch.abs(action_seq - decoded_action)) @@ -643,7 +709,9 @@ class VQBeTHead(nn.Module): "classification_loss": cbet_loss.detach().cpu().item(), "offset_loss": offset_loss.detach().cpu().item(), "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(), - "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(), + "equal_secondary_code_rate": equal_secondary_code_rate.detach() + .cpu() + .item(), "vq_action_error": vq_action_error.detach().cpu().item(), "offset_action_error": offset_action_error.detach().cpu().item(), "action_error_max": action_error_max.detach().cpu().item(), @@ -668,7 +736,9 @@ class VQBeTRgbEncoder(nn.Module): # Always use center crop for eval self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) if config.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + self.maybe_random_crop = torchvision.transforms.RandomCrop( + config.crop_shape + ) else: self.maybe_random_crop = self.center_crop else: @@ -689,7 +759,9 @@ class VQBeTRgbEncoder(nn.Module): self.backbone = _replace_submodules( root_module=self.backbone, predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + func=lambda x: nn.GroupNorm( + num_groups=x.num_features // 16, num_channels=x.num_features + ), ) # Set up pooling and final layers. @@ -730,7 +802,9 @@ class VQBeTRgbEncoder(nn.Module): def _replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] + root_module: nn.Module, + predicate: Callable[[nn.Module], bool], + func: Callable[[nn.Module], nn.Module], ) -> nn.Module: """ Args: @@ -743,7 +817,11 @@ def _replace_submodules( if predicate(root_module): return func(root_module) - replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + replace_list = [ + k.split(".") + for k, m in root_module.named_modules(remove_duplicate=True) + if predicate(m) + ] for *parents, k in replace_list: parent_module = root_module if len(parents) > 0: @@ -758,7 +836,9 @@ def _replace_submodules( else: setattr(parent_module, k, tgt_module) # verify that all BN are replaced - assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + assert not any( + predicate(m) for _, m in root_module.named_modules(remove_duplicate=True) + ) return root_module diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py index 139d119e..71e85ac0 100644 --- a/lerobot/common/policies/vqbet/vqbet_utils.py +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module): # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2) - k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) @@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module): att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -189,12 +197,16 @@ class GPT(nn.Module): "ln_f": nn.LayerNorm(config.gpt_hidden_dim), } ) - self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False) + self.lm_head = nn.Linear( + config.gpt_hidden_dim, config.gpt_output_dim, bias=False + ) # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith("c_proj.weight"): - torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)) + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer) + ) # report number of parameters n_params = sum(p.numel() for p in self.parameters()) @@ -208,11 +220,17 @@ class GPT(nn.Module): ) # positional encodings that are added to the input embeddings - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( + 0 + ) # shape (1, t) # forward the GPT model itself - tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim) + tok_emb = self.transformer.wte( + input + ) # token embeddings of shape (b, t, gpt_hidden_dim) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, gpt_hidden_dim) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) @@ -237,7 +255,9 @@ class GPT(nn.Module): # but want to use a smaller block size for some smaller, simpler model assert gpt_block_size <= self.config.gpt_block_size self.config.gpt_block_size = gpt_block_size - self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size]) + self.transformer.wpe.weight = nn.Parameter( + self.transformer.wpe.weight[:gpt_block_size] + ) for block in self.transformer.h: block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size] @@ -270,7 +290,9 @@ class GPT(nn.Module): param_dict = dict(self.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + assert ( + len(inter_params) == 0 + ), "parameters {} made it into both decay/no_decay sets!".format( str(inter_params) ) assert len(param_dict.keys() - union_params) == 0, ( @@ -368,8 +390,12 @@ class ResidualVQ(nn.Module): codebook_input_dim = codebook_dim * heads requires_projection = codebook_input_dim != dim - self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() - self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + self.project_in = ( + nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + ) self.num_quantizers = num_quantizers @@ -377,7 +403,10 @@ class ResidualVQ(nn.Module): self.layers = nn.ModuleList( [ VectorQuantize( - dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs + dim=codebook_dim, + codebook_dim=codebook_dim, + accept_image_fmap=accept_image_fmap, + **kwargs, ) for _ in range(num_quantizers) ] @@ -448,7 +477,9 @@ class ResidualVQ(nn.Module): return all_codes - def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None): + def forward( + self, x, indices=None, return_all_codes=False, sample_codebook_temp=None + ): """ For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss. First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize. @@ -477,13 +508,17 @@ class ResidualVQ(nn.Module): ) ce_losses = [] - should_quantize_dropout = self.training and self.quantize_dropout and not return_loss + should_quantize_dropout = ( + self.training and self.quantize_dropout and not return_loss + ) # sample a layer index at which to dropout further residual quantization # also prepare null indices and loss if should_quantize_dropout: - rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant) + rand_quantize_dropout_index = randrange( + self.quantize_dropout_cutoff_index, num_quant + ) if quant_dropout_multiple_of != 1: rand_quantize_dropout_index = ( @@ -492,14 +527,23 @@ class ResidualVQ(nn.Module): - 1 ) - null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2]) - null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long) + null_indices_shape = ( + (x.shape[0], *x.shape[-2:]) + if self.accept_image_fmap + else tuple(x.shape[:2]) + ) + null_indices = torch.full( + null_indices_shape, -1.0, device=device, dtype=torch.long + ) null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype) # go through the layers for quantizer_index, layer in enumerate(self.layers): - if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index: + if ( + should_quantize_dropout + and quantizer_index > rand_quantize_dropout_index + ): all_indices.append(null_indices) all_losses.append(null_loss) continue @@ -539,7 +583,9 @@ class ResidualVQ(nn.Module): # stack all losses and indices - all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices)) + all_losses, all_indices = map( + partial(torch.stack, dim=-1), (all_losses, all_indices) + ) ret = (quantized_out, all_indices, all_losses) @@ -599,8 +645,12 @@ class VectorQuantize(nn.Module): codebook_input_dim = codebook_dim * heads requires_projection = codebook_input_dim != dim - self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() - self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + self.project_in = ( + nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + ) self.eps = eps self.commitment_weight = commitment_weight @@ -614,10 +664,14 @@ class VectorQuantize(nn.Module): self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update" + assert not ( + ema_update and learnable_codebook + ), "learnable codebook not compatible with EMA update" assert 0 <= sync_update_v <= 1.0 - assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on" + assert not ( + sync_update_v > 0.0 and not learnable_codebook + ), "learnable codebook must be turned on" self.sync_update_v = sync_update_v @@ -629,7 +683,9 @@ class VectorQuantize(nn.Module): ) if sync_codebook is None: - sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1 + sync_codebook = ( + distributed.is_initialized() and distributed.get_world_size() > 1 + ) codebook_kwargs = { "dim": codebook_dim, @@ -794,11 +850,17 @@ class VectorQuantize(nn.Module): # quantize again - quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + quantize, embed_ind, distances = self._codebook( + x, **codebook_forward_kwargs + ) if self.training: # determine code to use for commitment loss - maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity + maybe_detach = ( + torch.detach + if not self.learnable_codebook or freeze_codebook + else identity + ) commit_quantize = maybe_detach(quantize) @@ -808,7 +870,9 @@ class VectorQuantize(nn.Module): if self.sync_update_v > 0.0: # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf - quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) + quantize = quantize + self.sync_update_v * ( + quantize - quantize.detach() + ) # function for calculating cross entropy loss to distance matrix # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss @@ -841,7 +905,9 @@ class VectorQuantize(nn.Module): embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads) if self.accept_image_fmap: - embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width) + embed_ind = rearrange( + embed_ind, "b (h w) ... -> b h w ...", h=height, w=width + ) if only_one: embed_ind = rearrange(embed_ind, "b 1 -> b") @@ -895,8 +961,12 @@ class VectorQuantize(nn.Module): num_codes = codebook.shape[-2] - if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes] + if ( + self.orthogonal_reg_max_codes is not None + ) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device=device)[ + : self.orthogonal_reg_max_codes + ] codebook = codebook[:, rand_ids] orthogonal_reg_loss = orthogonal_loss_fn(codebook) @@ -928,7 +998,9 @@ class VectorQuantize(nn.Module): # if masking, only return quantized for where mask has True if mask is not None: - quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input) + quantize = torch.where( + rearrange(mask, "... -> ... 1"), quantize, orig_input + ) return quantize, embed_ind, loss @@ -1038,7 +1110,9 @@ def sample_vectors(samples, num): def batched_sample_vectors(samples, num): - return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0) + return torch.stack( + [sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0 + ) def pad_shape(shape, size, dim=0): @@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num): all_num_samples = all_gather_sizes(local_samples, dim=0) if rank == 0: - samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) + samples_per_rank = sample_multinomial( + num, all_num_samples / all_num_samples.sum() + ) else: samples_per_rank = torch.empty_like(all_num_samples) @@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module): self.eps = eps self.threshold_ema_dead_code = threshold_ema_dead_code self.reset_cluster_size = ( - reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code + reset_cluster_size + if (reset_cluster_size is not None) + else threshold_ema_dead_code ) assert callable(gumbel_sample) @@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module): "kmeans init is not compatible with multiple codebooks in distributed environment for now" ) - self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors - self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop + self.sample_fn = ( + sample_vectors_distributed + if use_ddp and sync_kmeans + else batched_sample_vectors + ) + self.kmeans_all_reduce_fn = ( + distributed.all_reduce if use_ddp and sync_kmeans else noop + ) self.all_reduce_fn = distributed.all_reduce if use_ddp else noop self.register_buffer("initted", torch.Tensor([not kmeans_init])) @@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module): distributed.all_reduce(variance_number) batch_variance = variance_number / num_vectors - self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay) + self.update_with_decay( + "batch_variance", batch_variance, self.affine_param_batch_decay + ) def replace(self, batch_samples, batch_mask): for ind, (samples, mask) in enumerate( @@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module): if not torch.any(mask): continue - sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item()) + sampled = self.sample_fn( + rearrange(samples, "... -> 1 ..."), mask.sum().item() + ) sampled = rearrange(sampled, "1 ... -> ...") self.embed.data[ind][mask] = sampled @@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module): def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): needs_codebook_dim = x.ndim < 4 sample_codebook_temp = ( - sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp + sample_codebook_temp + if (sample_codebook_temp is not None) + else self.sample_codebook_temp ) x = x.float() @@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module): if self.affine_param: codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() batch_std = self.batch_variance.clamp(min=1e-5).sqrt() - embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean + embed = (embed - self.codebook_mean) * ( + batch_std / codebook_std + ) + self.batch_mean dist = -cdist(flatten, embed) @@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module): if self.training and self.ema_update and not freeze_codebook: if self.affine_param: - flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean + flatten = (flatten - self.batch_mean) * ( + codebook_std / batch_std + ) + self.codebook_mean if mask is not None: embed_onehot[~mask] = 0.0 @@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module): self.expire_codes_(x) if needs_codebook_dim: - quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)) + quantize, embed_ind = tuple( + rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind) + ) dist = unpack_one(dist, ps, "h * d") diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py index 7a21661a..1282007c 100644 --- a/lerobot/common/robot_devices/cameras/intelrealsense.py +++ b/lerobot/common/robot_devices/cameras/intelrealsense.py @@ -79,7 +79,9 @@ def save_image(img_array, serial_number, frame_index, images_dir): img.save(str(path), quality=100) logging.info(f"Saved image: {path}") except Exception as e: - logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") + logging.error( + f"Failed to save image for camera {serial_number} frame {frame_index}: {e}" + ) def save_images_from_cameras( @@ -157,7 +159,9 @@ def save_images_from_cameras( if time.perf_counter() - start_time > record_time_s: break - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") + print( + f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}" + ) frame_index += 1 finally: @@ -275,7 +279,9 @@ class IntelRealSenseCamera: f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." ) - name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} + name_to_serial_dict = { + cam["name"]: cam["serial_number"] for cam in camera_infos + } cam_sn = name_to_serial_dict[name] return cam_sn @@ -339,7 +345,9 @@ class IntelRealSenseCamera: actual_height = color_profile.height() # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + if self.fps is not None and not math.isclose( + self.fps, actual_fps, rel_tol=1e-3 + ): # Using `OSError` since it's a broad that encompasses issues related to device communication raise OSError( f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}." @@ -359,7 +367,9 @@ class IntelRealSenseCamera: self.is_connected = True - def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + def read( + self, temporary_color: str | None = None + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3) of type `np.uint8`, contrarily to the pytorch format which is float channel first. @@ -386,11 +396,15 @@ class IntelRealSenseCamera: color_frame = frame.get_color_frame() if not color_frame: - raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).") + raise OSError( + f"Can't capture color image from IntelRealSenseCamera({self.serial_number})." + ) color_image = np.asanyarray(color_frame.get_data()) - requested_color_mode = self.color_mode if temporary_color is None else temporary_color + requested_color_mode = ( + self.color_mode if temporary_color is None else temporary_color + ) if requested_color_mode not in ["rgb", "bgr"]: raise ValueError( f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." @@ -418,7 +432,9 @@ class IntelRealSenseCamera: if self.use_depth: depth_frame = frame.get_depth_frame() if not depth_frame: - raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).") + raise OSError( + f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})." + ) depth_map = np.asanyarray(depth_frame.get_data()) @@ -460,7 +476,9 @@ class IntelRealSenseCamera: # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here num_tries += 1 time.sleep(1 / self.fps) - if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): + if num_tries > self.fps and ( + self.thread.ident is None or not self.thread.is_alive() + ): raise Exception( "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." ) diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index f279f315..48111e97 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -45,10 +45,14 @@ from lerobot.common.utils.utils import capture_timestamp_utc MAX_OPENCV_INDEX = 60 -def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: +def find_cameras( + raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False +) -> list[dict]: cameras = [] if platform.system() == "Linux": - print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") + print( + "Linux detected. Finding available camera indices through scanning '/dev/video*' ports" + ) possible_ports = [str(port) for port in Path("/dev").glob("video*")] ports = _find_cameras(possible_ports, mock=mock) for port in ports: @@ -180,7 +184,9 @@ def save_images_from_cameras( dt_s = time.perf_counter() - now busy_wait(1 / fps - dt_s) - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") + print( + f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}" + ) if time.perf_counter() - start_time > record_time_s: break @@ -237,7 +243,9 @@ class OpenCVCamera: if platform.system() == "Linux": if isinstance(self.camera_index, int): self.port = Path(f"/dev/video{self.camera_index}") - elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): + elif isinstance(self.camera_index, str) and is_valid_unix_path( + self.camera_index + ): self.port = Path(self.camera_index) # Retrieve the camera index from a potentially symlinked path self.camera_index = get_camera_index_from_unix_port(self.port) @@ -283,7 +291,9 @@ class OpenCVCamera: def connect(self): if self.is_connected: - raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") + raise RobotDeviceAlreadyConnectedError( + f"OpenCVCamera({self.camera_index}) is already connected." + ) if self.mock: import tests.cameras.mock_cv2 as cv2 @@ -344,7 +354,9 @@ class OpenCVCamera: actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT) # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + if self.fps is not None and not math.isclose( + self.fps, actual_fps, rel_tol=1e-3 + ): # Using `OSError` since it's a broad that encompasses issues related to device communication raise OSError( f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}." @@ -386,7 +398,9 @@ class OpenCVCamera: if not ret: raise OSError(f"Can't capture color image from camera {self.camera_index}.") - requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode + requested_color_mode = ( + self.color_mode if temporary_color_mode is None else temporary_color_mode + ) if requested_color_mode not in ["rgb", "bgr"]: raise ValueError( diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 1703a52a..e3096a0e 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -39,7 +39,9 @@ from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import get_safe_torch_device, has_method -def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): +def log_control_info( + robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None +): log_items = [] if episode_index is not None: log_items.append(f"ep:{episode_index}") @@ -106,7 +108,9 @@ def predict_action(observation, policy, device, use_amp): observation = copy(observation) with ( torch.inference_mode(), - torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + torch.autocast(device_type=device.type) + if device.type == "cuda" and use_amp + else nullcontext(), ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: @@ -162,7 +166,9 @@ def init_keyboard_listener(assign_rewards=False): print("Right arrow key pressed. Exiting loop...") events["exit_early"] = True elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + print( + "Left arrow key pressed. Exiting loop and rerecord the last episode..." + ) events["rerecord_episode"] = True events["exit_early"] = True elif key == keyboard.Key.esc: @@ -256,7 +262,9 @@ def control_loop( raise ValueError("You need to provide a task as argument in `single_task`.") if dataset is not None and fps is not None and dataset.fps != fps: - raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") + raise ValueError( + f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})." + ) timestamp = 0 start_episode_t = time.perf_counter() @@ -291,7 +299,9 @@ def control_loop( if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.imshow( + key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + ) cv2.waitKey(1) if fps is not None: @@ -361,7 +371,11 @@ def sanity_check_dataset_name(repo_id, policy_cfg): def sanity_check_dataset_robot_compatibility( - dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None + dataset: LeRobotDataset, + robot: Robot, + fps: int, + use_videos: bool, + extra_features: dict = None, ) -> None: features_from_robot = get_features_from_robot(robot, use_videos) if extra_features is not None: @@ -375,11 +389,14 @@ def sanity_check_dataset_robot_compatibility( mismatches = [] for field, dataset_value, present_value in fields: - diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) + diff = DeepDiff( + dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"] + ) if diff: mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") if mismatches: raise ValueError( - "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) + "Dataset metadata compatibility check failed with mismatches:\n" + + "\n".join(mismatches) ) diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py index 6096ceb5..4721196d 100644 --- a/lerobot/common/robot_devices/motors/dynamixel.py +++ b/lerobot/common/robot_devices/motors/dynamixel.py @@ -158,7 +158,9 @@ NUM_READ_RETRY = 10 NUM_WRITE_RETRY = 10 -def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: +def convert_degrees_to_steps( + degrees: float | np.ndarray, models: str | list[str] +) -> np.ndarray: """This function converts the degree range to the step range for indicating motors rotation. It assumes a motor achieves a full rotation by going from -180 degree position to +180. The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. @@ -384,7 +386,9 @@ class DynamixelMotorsBus: indices = [] for idx in tqdm.tqdm(possible_ids): try: - present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] + present_idx = self.read_with_motor_ids( + self.motor_models, [idx], "ID", num_retry=num_retry + )[0] except ConnectionError: continue @@ -400,7 +404,9 @@ class DynamixelMotorsBus: def set_bus_baudrate(self, baudrate): present_bus_baudrate = self.port_handler.getBaudRate() if present_bus_baudrate != baudrate: - print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") + print( + f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}." + ) self.port_handler.setBaudRate(baudrate) if self.port_handler.getBaudRate() != baudrate: @@ -421,7 +427,9 @@ class DynamixelMotorsBus: def set_calibration(self, calibration: dict[str, list]): self.calibration = calibration - def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): + def apply_calibration_autocorrect( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct. For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. @@ -434,7 +442,9 @@ class DynamixelMotorsBus: values = self.apply_calibration(values, motor_names) return values - def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def apply_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with a "zero position" at 0 degree. @@ -509,7 +519,9 @@ class DynamixelMotorsBus: return values - def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def autocorrect_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """This function automatically detects issues with values of motors after calibration, and correct for these issues. Some motors might have values outside of expected maximum bounds after calibration. @@ -551,15 +563,23 @@ class DynamixelMotorsBus: values[i] *= -1 # Convert from initial range to range [-180, 180] degrees - calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) + calib_val = ( + (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + ) + in_range = (calib_val > LOWER_BOUND_DEGREE) and ( + calib_val < UPPER_BOUND_DEGREE + ) # Solve this inequality to find the factor to shift the range into [-180, 180] degrees # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution - low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution - upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution + low_factor = ( + -(resolution // 2) - values[i] - homing_offset + ) / resolution + upp_factor = ( + (resolution // 2) - values[i] - homing_offset + ) / resolution elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: start_pos = self.calibration["start_pos"][calib_idx] @@ -567,7 +587,9 @@ class DynamixelMotorsBus: # Convert from initial range to range [0, 100] in % calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) + in_range = (calib_val > LOWER_BOUND_LINEAR) and ( + calib_val < UPPER_BOUND_LINEAR + ) # Solve this inequality to find the factor to shift the range into [0, 100] % # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 @@ -583,19 +605,27 @@ class DynamixelMotorsBus: factor = math.ceil(low_factor) if factor > upp_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No integer found between bounds [{low_factor=}, {upp_factor=}]" + ) else: factor = math.ceil(upp_factor) if factor > low_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No integer found between bounds [{low_factor=}, {upp_factor=}]" + ) if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + out_of_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) + in_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) logging.warning( f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " @@ -605,7 +635,9 @@ class DynamixelMotorsBus: # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. self.calibration["homing_offset"][calib_idx] += resolution * factor - def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def revert_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """Inverse of `apply_calibration`.""" if motor_names is None: motor_names = self.motor_names @@ -644,7 +676,9 @@ class DynamixelMotorsBus: values = np.round(values).astype(np.int32) return values - def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): + def read_with_motor_ids( + self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY + ): if self.mock: import tests.motors.mock_dynamixel_sdk as dxl else: @@ -746,7 +780,9 @@ class DynamixelMotorsBus: values = self.apply_calibration_autocorrect(values, motor_names) # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) + delta_ts_name = get_log_name( + "delta_timestamp_s", "read", data_name, motor_names + ) self.logs[delta_ts_name] = time.perf_counter() - start_time # log the utc time at which the data was received @@ -755,7 +791,9 @@ class DynamixelMotorsBus: return values - def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): + def write_with_motor_ids( + self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY + ): if self.mock: import tests.motors.mock_dynamixel_sdk as dxl else: @@ -784,7 +822,12 @@ class DynamixelMotorsBus: f"{self.packet_handler.getTxRxResult(comm)}" ) - def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): + def write( + self, + data_name, + values: int | float | np.ndarray, + motor_names: str | list[str] | None = None, + ): if not self.is_connected: raise RobotDeviceNotConnectedError( f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." @@ -845,7 +888,9 @@ class DynamixelMotorsBus: ) # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) + delta_ts_name = get_log_name( + "delta_timestamp_s", "write", data_name, motor_names + ) self.logs[delta_ts_name] = time.perf_counter() - start_time # TODO(rcadene): should we log the time before sending the write command? diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py index 64c7f413..0941428c 100644 --- a/lerobot/common/robot_devices/motors/feetech.py +++ b/lerobot/common/robot_devices/motors/feetech.py @@ -137,7 +137,9 @@ NUM_READ_RETRY = 20 NUM_WRITE_RETRY = 20 -def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: +def convert_degrees_to_steps( + degrees: float | np.ndarray, models: str | list[str] +) -> np.ndarray: """This function converts the degree range to the step range for indicating motors rotation. It assumes a motor achieves a full rotation by going from -180 degree position to +180. The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. @@ -365,7 +367,9 @@ class FeetechMotorsBus: indices = [] for idx in tqdm.tqdm(possible_ids): try: - present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] + present_idx = self.read_with_motor_ids( + self.motor_models, [idx], "ID", num_retry=num_retry + )[0] except ConnectionError: continue @@ -381,7 +385,9 @@ class FeetechMotorsBus: def set_bus_baudrate(self, baudrate): present_bus_baudrate = self.port_handler.getBaudRate() if present_bus_baudrate != baudrate: - print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") + print( + f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}." + ) self.port_handler.setBaudRate(baudrate) if self.port_handler.getBaudRate() != baudrate: @@ -402,7 +408,9 @@ class FeetechMotorsBus: def set_calibration(self, calibration: dict[str, list]): self.calibration = calibration - def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): + def apply_calibration_autocorrect( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct. For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. @@ -415,7 +423,9 @@ class FeetechMotorsBus: values = self.apply_calibration(values, motor_names) return values - def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def apply_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with a "zero position" at 0 degree. @@ -489,7 +499,9 @@ class FeetechMotorsBus: return values - def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def autocorrect_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """This function automatically detects issues with values of motors after calibration, and correct for these issues. Some motors might have values outside of expected maximum bounds after calibration. @@ -528,18 +540,26 @@ class FeetechMotorsBus: values[i] *= -1 # Convert from initial range to range [-180, 180] degrees - calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) + calib_val = ( + (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + ) + in_range = (calib_val > LOWER_BOUND_DEGREE) and ( + calib_val < UPPER_BOUND_DEGREE + ) # Solve this inequality to find the factor to shift the range into [-180, 180] degrees # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution low_factor = ( - -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset + -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) + - values[i] + - homing_offset ) / resolution upp_factor = ( - HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset + HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) + - values[i] + - homing_offset ) / resolution elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: @@ -548,7 +568,9 @@ class FeetechMotorsBus: # Convert from initial range to range [0, 100] in % calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) + in_range = (calib_val > LOWER_BOUND_LINEAR) and ( + calib_val < UPPER_BOUND_LINEAR + ) # Solve this inequality to find the factor to shift the range into [0, 100] % # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 @@ -564,19 +586,27 @@ class FeetechMotorsBus: factor = math.ceil(low_factor) if factor > upp_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No integer found between bounds [{low_factor=}, {upp_factor=}]" + ) else: factor = math.ceil(upp_factor) if factor > low_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No integer found between bounds [{low_factor=}, {upp_factor=}]" + ) if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + out_of_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) + in_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) logging.warning( f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " @@ -586,7 +616,9 @@ class FeetechMotorsBus: # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. self.calibration["homing_offset"][calib_idx] += resolution * factor - def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def revert_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """Inverse of `apply_calibration`.""" if motor_names is None: motor_names = self.motor_names @@ -662,7 +694,9 @@ class FeetechMotorsBus: return values - def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): + def read_with_motor_ids( + self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY + ): if self.mock: import tests.motors.mock_scservo_sdk as scs else: @@ -771,7 +805,9 @@ class FeetechMotorsBus: values = self.apply_calibration_autocorrect(values, motor_names) # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) + delta_ts_name = get_log_name( + "delta_timestamp_s", "read", data_name, motor_names + ) self.logs[delta_ts_name] = time.perf_counter() - start_time # log the utc time at which the data was received @@ -780,7 +816,9 @@ class FeetechMotorsBus: return values - def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): + def write_with_motor_ids( + self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY + ): if self.mock: import tests.motors.mock_scservo_sdk as scs else: @@ -809,7 +847,12 @@ class FeetechMotorsBus: f"{self.packet_handler.getTxRxResult(comm)}" ) - def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): + def write( + self, + data_name, + values: int | float | np.ndarray, + motor_names: str | list[str] | None = None, + ): if not self.is_connected: raise RobotDeviceNotConnectedError( f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." @@ -870,7 +913,9 @@ class FeetechMotorsBus: ) # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) + delta_ts_name = get_log_name( + "delta_timestamp_s", "write", data_name, motor_names + ) self.logs[delta_ts_name] = time.perf_counter() - start_time # TODO(rcadene): should we log the time before sending the write command? diff --git a/lerobot/common/robot_devices/robots/dynamixel_calibration.py b/lerobot/common/robot_devices/robots/dynamixel_calibration.py index 98fe8754..8eb60c9d 100644 --- a/lerobot/common/robot_devices/robots/dynamixel_calibration.py +++ b/lerobot/common/robot_devices/robots/dynamixel_calibration.py @@ -24,9 +24,7 @@ from lerobot.common.robot_devices.motors.dynamixel import ( ) from lerobot.common.robot_devices.motors.utils import MotorsBus -URL_TEMPLATE = ( - "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" -) +URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" # The following positions are provided in nominal degree range ]-180, +180[ # For more info on these constants, see comments in the code where they get used. @@ -37,7 +35,9 @@ ROTATED_POSITION_DEGREE = 90 def assert_drive_mode(drive_mode): # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") + raise ValueError( + f"`drive_mode` contains values other than 0 or 1: ({drive_mode})" + ) def apply_drive_mode(position, drive_mode): @@ -78,12 +78,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type ``` """ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to zero position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) + print( + "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero") + ) input("Press Enter to continue...") # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. @@ -104,10 +108,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view # of the previous motor in the kinetic chain. print("\nMove arm to rotated target position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated") + ) input("Press Enter to continue...") - rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) + rotated_target_pos = convert_degrees_to_steps( + ROTATED_POSITION_DEGREE, arm.motor_models + ) # Find drive mode by rotating each motor by a quarter of a turn. # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). @@ -116,11 +125,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type # Re-compute homing offset to take into account drive mode rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) - rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models) + rotated_nearest_pos = compute_nearest_rounded_position( + rotated_drived_pos, arm.motor_models + ) homing_offset = rotated_target_pos - rotated_nearest_pos print("\nMove arm to rest position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) + print( + "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest") + ) input("Press Enter to continue...") print() diff --git a/lerobot/common/robot_devices/robots/feetech_calibration.py b/lerobot/common/robot_devices/robots/feetech_calibration.py index 2c1e7180..f3a59a0b 100644 --- a/lerobot/common/robot_devices/robots/feetech_calibration.py +++ b/lerobot/common/robot_devices/robots/feetech_calibration.py @@ -26,9 +26,7 @@ from lerobot.common.robot_devices.motors.feetech import ( ) from lerobot.common.robot_devices.motors.utils import MotorsBus -URL_TEMPLATE = ( - "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" -) +URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" # The following positions are provided in nominal degree range ]-180, +180[ # For more info on these constants, see comments in the code where they get used. @@ -39,7 +37,9 @@ ROTATED_POSITION_DEGREE = 90 def assert_drive_mode(drive_mode): # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") + raise ValueError( + f"`drive_mode` contains values other than 0 or 1: ({drive_mode})" + ) def apply_drive_mode(position, drive_mode): @@ -140,7 +140,9 @@ def apply_offset(calib, offset): return calib -def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_auto_calibration( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): if robot_type == "so100": return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type) elif robot_type == "moss": @@ -149,18 +151,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm raise ValueError(robot_type) -def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_auto_calibration_so100( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) if not (robot_type == "so100" and arm_type == "follower"): - raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.") + raise NotImplementedError( + "Auto calibration only supports the follower of so100 arms for now." + ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to initial position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial") + ) input("Press Enter to continue...") # Lower the acceleration of the motors (in [0,254]) @@ -207,11 +218,16 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st print("Calibrate elbow_flex") calib["elbow_flex"] = move_to_calibrate( - arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook + arm, + "elbow_flex", + positive_first=False, + in_between_move_hook=in_between_move_hook, ) calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex") + arm.write( + "Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex" + ) time.sleep(1) def in_between_move_hook(): @@ -239,18 +255,30 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st } arm.write("Goal_Position", list(positions.values()), list(positions.keys())) - arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift") + arm.write( + "Goal_Position", + round(calib["shoulder_lift"]["zero_pos"] - 1600), + "shoulder_lift", + ) time.sleep(2) - arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex") + arm.write( + "Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex" + ) time.sleep(2) - arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex") + arm.write( + "Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex" + ) time.sleep(2) arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper") time.sleep(2) print("Calibrate wrist_roll") calib["wrist_roll"] = move_to_calibrate( - arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook + arm, + "wrist_roll", + invert_drive_mode=True, + positive_first=False, + while_move_hook=while_move_hook, ) arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") @@ -260,7 +288,9 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") time.sleep(1) arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex") - arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift") + arm.write( + "Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift" + ) time.sleep(1) arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") time.sleep(1) @@ -289,18 +319,27 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st return calib_dict -def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_auto_calibration_moss( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) if not (robot_type == "moss" and arm_type == "follower"): - raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.") + raise NotImplementedError( + "Auto calibration only supports the follower of moss arms for now." + ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to initial position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial") + ) input("Press Enter to continue...") # Lower the acceleration of the motors (in [0,254]) @@ -384,8 +423,12 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") time.sleep(1) - arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift") - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex") + arm.write( + "Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift" + ) + arm.write( + "Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex" + ) time.sleep(2) calib_modes = [] @@ -412,7 +455,9 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str return calib_dict -def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_manual_calibration( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): """This function ensures that a neural network trained on data collected on a given robot can work on another robot. For instance before calibration, setting a same goal position for each motor of two different robots will get two very different positions. But after calibration, @@ -435,12 +480,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a ``` """ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to zero position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) + print( + "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero") + ) input("Press Enter to continue...") # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. @@ -460,10 +509,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view # of the previous motor in the kinetic chain. print("\nMove arm to rotated target position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated") + ) input("Press Enter to continue...") - rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) + rotated_target_pos = convert_degrees_to_steps( + ROTATED_POSITION_DEGREE, arm.motor_models + ) # Find drive mode by rotating each motor by a quarter of a turn. # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). @@ -475,7 +529,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a homing_offset = rotated_target_pos - rotated_drived_pos print("\nMove arm to rest position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) + print( + "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest") + ) input("Press Enter to continue...") print() diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 05ced833..334100ca 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -31,11 +31,16 @@ from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig from lerobot.common.robot_devices.robots.utils import get_arm_id -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, +) def ensure_safe_goal_position( - goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float] + goal_pos: torch.Tensor, + present_pos: torch.Tensor, + max_relative_target: float | list[float], ): # Cap relative action target magnitude for safety. diff = goal_pos - present_pos @@ -277,7 +282,9 @@ class ManipulatorRobot: # to squeeze the gripper and have it spring back to an open position on its own. for name in self.leader_arms: self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") + self.leader_arms[name].write( + "Goal_Position", self.config.gripper_open_degree, "gripper" + ) # Check both arms can be read for name in self.follower_arms: @@ -309,18 +316,26 @@ class ManipulatorRobot: print(f"Missing calibration file '{arm_calib_path}'") if self.robot_type in ["koch", "koch_bimanual", "aloha"]: - from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration + from lerobot.common.robot_devices.robots.dynamixel_calibration import ( + run_arm_calibration, + ) - calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) + calibration = run_arm_calibration( + arm, self.robot_type, name, arm_type + ) elif self.robot_type in ["so100", "moss", "lekiwi"]: from lerobot.common.robot_devices.robots.feetech_calibration import ( run_arm_manual_calibration, ) - calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) + calibration = run_arm_manual_calibration( + arm, self.robot_type, name, arm_type + ) - print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") + print( + f"Calibration is done! Saving calibration file '{arm_calib_path}'" + ) arm_calib_path.parent.mkdir(parents=True, exist_ok=True) with open(arm_calib_path, "w") as f: json.dump(calibration, f) @@ -339,13 +354,17 @@ class ManipulatorRobot: from lerobot.common.robot_devices.motors.dynamixel import TorqueMode if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run set robot preset, the torque must be disabled on all motors.") + raise ValueError( + "To run set robot preset, the torque must be disabled on all motors." + ) # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, # you could end up with a servo with a position 0 or 4095 at a crucial point See [ # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] - all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] + all_motors_except_gripper = [ + name for name in arm.motor_names if name != "gripper" + ] if len(all_motors_except_gripper) > 0: # 4 corresponds to Extended Position on Koch motors arm.write("Operating_Mode", 4, all_motors_except_gripper) @@ -374,7 +393,9 @@ class ManipulatorRobot: # Enable torque on the gripper of the leader arms, and move it to 45 degrees, # so that we can use it as a trigger to close the gripper of the follower arms. self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") + self.leader_arms[name].write( + "Goal_Position", self.config.gripper_open_degree, "gripper" + ) def set_aloha_robot_preset(self): def set_shadow_(arm): @@ -404,11 +425,15 @@ class ManipulatorRobot: # you could end up with a servo with a position 0 or 4095 at a crucial point See [ # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] all_motors_except_gripper = [ - name for name in self.follower_arms[name].motor_names if name != "gripper" + name + for name in self.follower_arms[name].motor_names + if name != "gripper" ] if len(all_motors_except_gripper) > 0: # 4 corresponds to Extended Position on Aloha motors - self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper) + self.follower_arms[name].write( + "Operating_Mode", 4, all_motors_except_gripper + ) # Use 'position control current based' for follower gripper to be limited by the limit of the current. # It can grasp an object without forcing too much even tho, @@ -456,7 +481,9 @@ class ManipulatorRobot: before_lread_t = time.perf_counter() leader_pos[name] = self.leader_arms[name].read("Present_Position") leader_pos[name] = torch.from_numpy(leader_pos[name]) - self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t + self.logs[f"read_leader_{name}_pos_dt_s"] = ( + time.perf_counter() - before_lread_t + ) # Send goal position to the follower follower_goal_pos = {} @@ -477,14 +504,18 @@ class ManipulatorRobot: if self.config.max_relative_target is not None: present_pos = self.follower_arms[name].read("Present_Position") present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + goal_pos = ensure_safe_goal_position( + goal_pos, present_pos, self.config.max_relative_target + ) # Used when record_data=True follower_goal_pos[name] = goal_pos goal_pos = goal_pos.numpy().astype(np.float32) self.follower_arms[name].write("Goal_Position", goal_pos) - self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t + self.logs[f"write_follower_{name}_goal_pos_dt_s"] = ( + time.perf_counter() - before_fwrite_t + ) # Early exit when recording data is not requested if not record_data: @@ -497,7 +528,9 @@ class ManipulatorRobot: before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t + self.logs[f"read_follower_{name}_pos_dt_s"] = ( + time.perf_counter() - before_fread_t + ) # Create state by concatenating follower current position state = [] @@ -519,8 +552,12 @@ class ManipulatorRobot: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionaries obs_dict, action_dict = {}, {} @@ -544,7 +581,9 @@ class ManipulatorRobot: before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t + self.logs[f"read_follower_{name}_pos_dt_s"] = ( + time.perf_counter() - before_fread_t + ) # Create state by concatenating follower current position state = [] @@ -559,8 +598,12 @@ class ManipulatorRobot: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionaries and format to pytorch obs_dict = {} @@ -606,7 +649,9 @@ class ManipulatorRobot: if self.config.max_relative_target is not None: present_pos = self.follower_arms[name].read("Present_Position") present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + goal_pos = ensure_safe_goal_position( + goal_pos, present_pos, self.config.max_relative_target + ) # Save tensor to concat and return action_sent.append(goal_pos) diff --git a/lerobot/common/robot_devices/robots/stretch.py b/lerobot/common/robot_devices/robots/stretch.py index 9cfe6e49..813732f0 100644 --- a/lerobot/common/robot_devices/robots/stretch.py +++ b/lerobot/common/robot_devices/robots/stretch.py @@ -52,7 +52,9 @@ class StretchRobot(StretchAPI): def connect(self) -> None: self.is_connected = self.startup() if not self.is_connected: - print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") + print( + "Another process is already using Stretch. Try running 'stretch_free_robot_process.py'" + ) raise ConnectionError() for name in self.cameras: @@ -60,7 +62,9 @@ class StretchRobot(StretchAPI): self.is_connected = self.is_connected and self.cameras[name].is_connected if not self.is_connected: - print("Could not connect to the cameras, check that all cameras are plugged-in.") + print( + "Could not connect to the cameras, check that all cameras are plugged-in." + ) raise ConnectionError() self.run_calibration() @@ -105,8 +109,12 @@ class StretchRobot(StretchAPI): before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionaries obs_dict, action_dict = {}, {} @@ -150,8 +158,12 @@ class StretchRobot(StretchAPI): before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionaries obs_dict = {} diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index 837c9d2e..5c948e16 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -48,7 +48,8 @@ class RobotDeviceNotConnectedError(Exception): """Exception raised when the robot device is not connected.""" def __init__( - self, message="This robot device is not connected. Try calling `robot_device.connect()` first." + self, + message="This robot device is not connected. Try calling `robot_device.connect()` first.", ): self.message = message super().__init__(self.message) diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index cd5f8245..e2ce5a87 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -17,7 +17,9 @@ import importlib import logging -def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: +def is_package_available( + pkg_name: str, return_version: bool = False +) -> tuple[bool, str] | bool: """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py Check if the package spec exists and grab its version to avoid importing a local directory. **Note:** this doesn't work for all packages. diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index da0be1c7..c67d8e1e 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -28,7 +28,9 @@ def write_video(video_path, stacked_frames, fps): # Filter out DeprecationWarnings raised from pkg_resources with warnings.catch_warnings(): warnings.filterwarnings( - "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning + "ignore", + "pkg_resources is deprecated as an API", + category=DeprecationWarning, ) imageio.mimsave(video_path, stacked_frames, fps=fps) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 984f2e38..56e151c9 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -148,7 +148,10 @@ def _relative_path_between(path1: Path, path2: Path) -> Path: except ValueError: # most likely because path1 is not a subpath of path2 common_parts = Path(osp.commonpath([path1, path2])).parts return Path( - "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) + "/".join( + [".."] * (len(path2.parts) - len(common_parts)) + + list(path1.parts[len(common_parts) :]) + ) ) @@ -159,10 +162,26 @@ def print_cuda_memory_usage(): gc.collect() # Also clear the cache if you want to fully release the memory torch.cuda.empty_cache() - print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) - print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) - print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) - print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) + print( + "Current GPU Memory Allocated: {:.2f} MB".format( + torch.cuda.memory_allocated(0) / 1024**2 + ) + ) + print( + "Maximum GPU Memory Allocated: {:.2f} MB".format( + torch.cuda.max_memory_allocated(0) / 1024**2 + ) + ) + print( + "Current GPU Memory Reserved: {:.2f} MB".format( + torch.cuda.memory_reserved(0) / 1024**2 + ) + ) + print( + "Maximum GPU Memory Reserved: {:.2f} MB".format( + torch.cuda.max_memory_reserved(0) / 1024**2 + ) + ) def capture_timestamp_utc(): @@ -232,7 +251,12 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool: class TimerManager: - def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True): + def __init__( + self, + elapsed_time_list: list[float] | None = None, + label="Elapsed time", + log=True, + ): self.label = label self.elapsed_time_list = elapsed_time_list self.log = log diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml index bceeae59..1bd5cd83 100644 --- a/lerobot/configs/env/so100_real.yaml +++ b/lerobot/configs/env/so100_real.yaml @@ -9,7 +9,7 @@ env: action_dim: 6 fps: ${fps} device: mps - + wrapper: crop_params_dict: observation.images.front: [102, 43, 358, 523] @@ -28,4 +28,4 @@ env: reward_classifier: pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model config_path: lerobot/configs/policy/hilserl_classifier.yaml - \ No newline at end of file + diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml index 87fc4095..c954b1ea 100644 --- a/lerobot/configs/policy/sac_maniskill.yaml +++ b/lerobot/configs/policy/sac_maniskill.yaml @@ -66,7 +66,7 @@ policy: observation.image: [3, 64, 64] output_shapes: action: [7] - + camera_number: 1 # Normalization / Unnormalization @@ -79,7 +79,7 @@ policy: # 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00, # -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00, # -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01, - # 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] + # 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] # max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, # 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py index b0dc8a97..3b395129 100644 --- a/lerobot/scripts/configure_motor.py +++ b/lerobot/scripts/configure_motor.py @@ -108,20 +108,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): break if motor_index == -1: - raise ValueError("No motors detected. Please ensure you have one motor connected.") + raise ValueError( + "No motors detected. Please ensure you have one motor connected." + ) print(f"Motor index found at: {motor_index}") if brand == "feetech": # Allows ID and BAUDRATE to be written in memory - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) + motor_bus.write_with_motor_ids( + motor_bus.motor_models, motor_index, "Lock", 0 + ) if baudrate != baudrate_des: print(f"Setting its baudrate to {baudrate_des}") baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des) # The write can fail, so we allow retries - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx) + motor_bus.write_with_motor_ids( + motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx + ) time.sleep(0.5) motor_bus.set_bus_baudrate(baudrate_des) present_baudrate_idx = motor_bus.read_with_motor_ids( @@ -136,7 +142,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des) - present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2) + present_idx = motor_bus.read_with_motor_ids( + motor_bus.motor_models, motor_idx_des, "ID", num_retry=2 + ) if present_idx != motor_idx_des: raise OSError("Failed to write index.") @@ -164,12 +172,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)") - parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)") - parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)") - parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)") parser.add_argument( - "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)" + "--port", + type=str, + required=True, + help="Motors bus port (e.g. dynamixel,feetech)", + ) + parser.add_argument( + "--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)" + ) + parser.add_argument( + "--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)" + ) + parser.add_argument( + "--ID", + type=int, + required=True, + help="Desired ID of the current motor (e.g. 1,2,3)", + ) + parser.add_argument( + "--baudrate", + type=int, + default=1000000, + help="Desired baudrate for the motor (default: 1000000)", ) args = parser.parse_args() diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 00ca7d8e..1eea4223 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -149,7 +149,11 @@ def init_sim_calibration(robot, cfg): axis_directions = np.array(cfg.get("axis_directions", [1])) offsets = np.array(cfg.get("offsets", [0])) * np.pi - return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets} + return { + "start_pos": start_pos, + "axis_directions": axis_directions, + "offsets": offsets, + } def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets): @@ -170,7 +174,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None): leader_pos = robot.leader_arms.main.read("Present_Position") action = process_action_fn(leader_pos) env.step(np.expand_dims(action, 0)) - if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: + if ( + teleop_time_s is not None + and time.perf_counter() - start_teleop_t > teleop_time_s + ): print("Teleoperation processes finished.") break @@ -202,19 +209,27 @@ def record( # Load pretrained policy extra_features = ( - {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None + {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} + if assign_rewards + else None ) policy = None if pretrained_policy_name_or_path is not None: - policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) + policy, policy_fps, device, use_amp = init_policy( + pretrained_policy_name_or_path, policy_overrides + ) if fps is None: fps = policy_fps - logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") + logging.warning( + f"No fps provided, so using the fps from policy config ({policy_fps})." + ) if policy is None and process_action_from_leader is None: - raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") + raise ValueError( + "Either policy or process_action_fn has to be set to enable control in sim." + ) # initialize listener before sim env listener, events = init_keyboard_listener(assign_rewards=assign_rewards) @@ -256,7 +271,11 @@ def record( "shape": env.observation_space[obs_key].shape, } - features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} + features["action"] = { + "dtype": "float32", + "shape": env.action_space.shape, + "names": None, + } features = {**features, **extra_features} # Create empty dataset or load existing saved episodes @@ -357,7 +376,9 @@ def record( if events["stop_recording"] or recorded_episodes >= num_episodes: break else: - logging.info("Waiting for a few seconds before starting next episode recording...") + logging.info( + "Waiting for a few seconds before starting next episode recording..." + ) busy_wait(3) log_say("Stop recording", play_sounds, blocking=True) @@ -375,7 +396,12 @@ def record( def replay( - env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True + env, + root: Path, + repo_id: str, + episode: int, + fps: int | None = None, + local_files_only: bool = True, ): env = env() @@ -422,7 +448,10 @@ if __name__ == "__main__": parser_record = subparsers.add_parser("record", parents=[base_parser]) parser_record.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + "--fps", + type=none_or_int, + default=None, + help="Frames per second (set to None to disable)", ) parser_record.add_argument( "--root", @@ -448,7 +477,9 @@ if __name__ == "__main__": required=True, help="A description of the task preformed during recording that can be used as a language instruction.", ) - parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") + parser_record.add_argument( + "--num-episodes", type=int, default=50, help="Number of episodes to record." + ) parser_record.add_argument( "--run-compute-stats", type=int, @@ -509,7 +540,10 @@ if __name__ == "__main__": parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + "--fps", + type=none_or_int, + default=None, + help="Frames per second (set to None to disable)", ) parser_replay.add_argument( "--root", @@ -523,7 +557,9 @@ if __name__ == "__main__": default="lerobot/test", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", ) - parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.") + parser_replay.add_argument( + "--episode", type=int, default=0, help="Index of the episodes to replay." + ) args = parser.parse_args() diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py index 4d3cc291..2d844990 100644 --- a/lerobot/scripts/display_sys_info.py +++ b/lerobot/scripts/display_sys_info.py @@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A" torch_version = torch.__version__ if HAS_TORCH else "N/A" torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" -cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" +cuda_version = ( + torch._C._cuda_getCompiledVersion() + if HAS_TORCH and torch.version.cuda is not None + else "N/A" +) # TODO(aliberts): refactor into an actual command `lerobot env` @@ -77,7 +81,9 @@ def display_sys_info() -> dict: "Using GPU in script?": "", # "Using distributed or parallel set-up in script?": "", } - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") + print( + "\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n" + ) print(format_dict(info)) return info diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index d7a4201f..3c679530 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -170,7 +170,10 @@ def rollout( # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't # available of none of the envs finished. if "final_info" in info: - successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + successes = [ + info["is_success"] if info is not None else False + for info in info["final_info"] + ] else: successes = [False] * env.num_envs @@ -184,9 +187,13 @@ def rollout( step += 1 running_success_rate = ( - einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() + einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any") + .numpy() + .mean() + ) + progbar.set_postfix( + {"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"} ) - progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}) progbar.update() # Track the final observation. @@ -204,7 +211,9 @@ def rollout( if return_observations: stacked_observations = {} for key in all_observations[0]: - stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) + stacked_observations[key] = torch.stack( + [obs[key] for obs in all_observations], dim=1 + ) ret["observation"] = stacked_observations if hasattr(policy, "use_original_modules"): @@ -266,7 +275,9 @@ def eval_policy( return n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) if isinstance(env, gym.vector.SyncVectorEnv): - ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023 + ep_frames.append( + np.stack([env.envs[i].render() for i in range(n_to_render_now)]) + ) # noqa: B023 elif isinstance(env, gym.vector.AsyncVectorEnv): # Here we must render all frames and discard any we don't need. ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) @@ -278,7 +289,9 @@ def eval_policy( episode_data: dict | None = None # we dont want progress bar when we use slurm, since it clutters the logs - progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) + progbar = trange( + n_batches, desc="Stepping through eval batches", disable=inside_slurm() + ) for batch_ix in progbar: # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout # step. @@ -289,7 +302,8 @@ def eval_policy( seeds = None else: seeds = range( - start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) + start_seed + (batch_ix * env.num_envs), + start_seed + ((batch_ix + 1) * env.num_envs), ) rollout_data = rollout( env, @@ -307,13 +321,22 @@ def eval_policy( # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. - mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() + mask = ( + torch.arange(n_steps) + <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps) + ).int() # Extend metrics. - batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum") + batch_sum_rewards = einops.reduce( + (rollout_data["reward"] * mask), "b n -> b", "sum" + ) sum_rewards.extend(batch_sum_rewards.tolist()) - batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max") + batch_max_rewards = einops.reduce( + (rollout_data["reward"] * mask), "b n -> b", "max" + ) max_rewards.extend(batch_max_rewards.tolist()) - batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") + batch_successes = einops.reduce( + (rollout_data["success"] * mask), "b n -> b", "any" + ) all_successes.extend(batch_successes.tolist()) if seeds: all_seeds.extend(seeds) @@ -326,17 +349,27 @@ def eval_policy( rollout_data, done_indices, start_episode_index=batch_ix * env.num_envs, - start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), + start_data_index=( + 0 + if episode_data is None + else (episode_data["index"][-1].item() + 1) + ), fps=env.unwrapped.metadata["render_fps"], ) if episode_data is None: episode_data = this_episode_data else: # Some sanity checks to make sure we are correctly compiling the data. - assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0] + assert ( + episode_data["episode_index"][-1] + 1 + == this_episode_data["episode_index"][0] + ) assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] # Concatenate the episode data. - episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data} + episode_data = { + k: torch.cat([episode_data[k], this_episode_data[k]]) + for k in episode_data + } # Maybe render video for visualization. if max_episodes_rendered > 0 and len(ep_frames) > 0: @@ -354,7 +387,9 @@ def eval_policy( target=write_video, args=( str(video_path), - stacked_frames[: done_index + 1], # + 1 to capture the last observation + stacked_frames[ + : done_index + 1 + ], # + 1 to capture the last observation env.unwrapped.metadata["render_fps"], ), ) @@ -363,7 +398,9 @@ def eval_policy( n_episodes_rendered += 1 progbar.set_postfix( - {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"} + { + "running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%" + } ) # Wait till all video rendering threads are done. @@ -409,7 +446,11 @@ def eval_policy( def _compile_episode_data( - rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float + rollout_data: dict, + done_indices: Tensor, + start_episode_index: int, + start_data_index: int, + fps: float, ) -> dict: """Convenience function for `eval_policy(return_episode_data=True)` @@ -427,12 +468,16 @@ def _compile_episode_data( # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. ep_dict = { "action": rollout_data["action"][ep_ix, : num_frames - 1], - "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), + "episode_index": torch.tensor( + [start_episode_index + ep_ix] * (num_frames - 1) + ), "frame_index": torch.arange(0, num_frames - 1, 1), "timestamp": torch.arange(0, num_frames - 1, 1) / fps, "next.done": rollout_data["done"][ep_ix, : num_frames - 1], "next.success": rollout_data["success"][ep_ix, : num_frames - 1], - "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), + "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type( + torch.float32 + ), } # For the last observation frame, all other keys will just be copy padded. @@ -448,7 +493,9 @@ def _compile_episode_data( for key in ep_dicts[0]: data_dict[key] = torch.cat([x[key] for x in ep_dicts]) - data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) + data_dict["index"] = torch.arange( + start_data_index, start_data_index + total_frames, 1 + ) return data_dict diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py index 842c1a28..8a7062e7 100644 --- a/lerobot/scripts/eval_on_robot.py +++ b/lerobot/scripts/eval_on_robot.py @@ -46,7 +46,11 @@ import torch from tqdm import trange from lerobot.common.policies.policy_protocol import Policy -from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position +from lerobot.common.robot_devices.control_utils import ( + busy_wait, + is_headless, + reset_follower_position, +) from lerobot.common.robot_devices.robots.factory import Robot, make_robot from lerobot.common.utils.utils import ( init_hydra_config, @@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path): return from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg - from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig - from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( + ClassifierConfig, + ) + from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ) cfg = init_hydra_config(config_path) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) - classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths + classifier_config.num_cameras = len( + cfg.training.image_keys + ) # TODO automate these paths model = Classifier(classifier_config) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model = model.to("mps") @@ -151,11 +161,17 @@ def rollout( images = [] for key in image_keys: if display_cameras: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.imshow( + key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + ) cv2.waitKey(1) images.append(observation[key].to("mps")) - reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0 + reward = ( + reward_classifier.predict_reward(images) + if reward_classifier is not None + else 0.0 + ) all_rewards.append(reward) # print("REWARD : ", reward) @@ -219,11 +235,19 @@ def eval_policy( start_eval = time.perf_counter() progbar = trange(n_episodes, desc="Evaluating policy on real robot") - reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file) + reward_classifier = get_classifier( + reward_classifier_pretrained_path, reward_classifier_config_file + ) for _ in progbar: rollout_data = rollout( - robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras + robot, + policy, + reward_classifier, + fps, + control_time_s, + use_amp, + display_cameras, ) rollouts.append(rollout_data) @@ -289,7 +313,9 @@ def init_keyboard_listener(): print("Right arrow key pressed. Exiting loop...") events["exit_early"] = True elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + print( + "Left arrow key pressed. Exiting loop and rerecord the last episode..." + ) events["rerecord_episode"] = True events["exit_early"] = True elif key == keyboard.Key.space: @@ -301,7 +327,10 @@ def init_keyboard_listener(): "Place the leader in similar pose to the follower and press space again." ) events["pause_policy"] = True - log_say("Human intervention stage. Get ready to take over.", play_sounds=True) + log_say( + "Human intervention stage. Get ready to take over.", + play_sounds=True, + ) else: events["human_intervention_step"] = True print("Space key pressed. Human intervention starting.") @@ -351,7 +380,9 @@ if __name__ == "__main__": "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." ), ) - parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") + parser.add_argument( + "--revision", help="Optionally provide the Hugging Face Hub revision ID." + ) parser.add_argument( "--out-dir", help=( @@ -360,7 +391,8 @@ if __name__ == "__main__": ), ) parser.add_argument( - "--display-cameras", help=("Whether to display the camera feed while the rollout is happening") + "--display-cameras", + help=("Whether to display the camera feed while the rollout is happening"), ) parser.add_argument( "--reward-classifier-pretrained-path", diff --git a/lerobot/scripts/find_motors_bus_port.py b/lerobot/scripts/find_motors_bus_port.py index 68f2315d..ca56bf48 100644 --- a/lerobot/scripts/find_motors_bus_port.py +++ b/lerobot/scripts/find_motors_bus_port.py @@ -45,9 +45,13 @@ def find_port(): print(f"The port of this MotorsBus is '{port}'") print("Reconnect the USB cable.") elif len(ports_diff) == 0: - raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).") + raise OSError( + f"Could not detect the port. No difference was found ({ports_diff})." + ) else: - raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") + raise OSError( + f"Could not detect the port. More than one port was found ({ports_diff})." + ) if __name__ == "__main__": diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 905157f1..f93b40ca 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -import random from typing import Any, Callable, Optional, Sequence, TypedDict import io @@ -737,7 +736,6 @@ def concatenate_batch_transitions( if __name__ == "__main__": - import numpy as np from tempfile import TemporaryDirectory # ===== Test 1: Create and use a synthetic ReplayBuffer ===== @@ -1139,7 +1137,7 @@ if __name__ == "__main__": savings_percent = (std_mem - opt_mem) / std_mem * 100 - print(f"\nMemory optimization result:") + print("\nMemory optimization result:") print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB") print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB") print(f"- Memory savings for state tensors: {savings_percent:.1f}%") diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py index fb9077c9..8bb414fe 100644 --- a/lerobot/scripts/server/crop_dataset_roi.py +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.") + parser = argparse.ArgumentParser( + description="Crop rectangular ROIs from a LeRobot dataset." + ) parser.add_argument( "--repo-id", type=str, @@ -247,7 +249,9 @@ if __name__ == "__main__": args = parser.parse_args() local_files_only = args.root is not None - dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only) + dataset = LeRobotDataset( + repo_id=args.repo_id, root=args.root, local_files_only=local_files_only + ) images = get_image_from_lerobot_dataset(dataset) images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} @@ -256,7 +260,7 @@ if __name__ == "__main__": if args.crop_params_path is None: rois = select_square_roi_for_images(images) else: - with open(args.crop_params_path, "r") as f: + with open(args.crop_params_path) as f: rois = json.load(f) # rois = { diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py index 1c2443d6..d5870027 100644 --- a/lerobot/scripts/server/find_joint_limits.py +++ b/lerobot/scripts/server/find_joint_limits.py @@ -31,7 +31,9 @@ def find_joint_bounds( if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.imshow( + key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + ) cv2.waitKey(1) timestamp = time.perf_counter() - start_episode_t @@ -57,7 +59,12 @@ if __name__ == "__main__": nargs="*", help="Any key=value arguments to override config values (use dots for.nested=overrides)", ) - parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds") + parser.add_argument( + "--control-time-s", + type=float, + default=20, + help="Maximum episode length in seconds", + ) args = parser.parse_args() robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index baba99e7..4bab9ac2 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: def initialize_replay_buffer( - cfg: DictConfig, logger: Logger, device: str, storage_device:str + cfg: DictConfig, logger: Logger, device: str, storage_device: str ) -> ReplayBuffer: if not cfg.resume: return ReplayBuffer( diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index e1c0840a..b9c9d216 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -10,7 +10,9 @@ from typing import Any from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv -def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]: +def preprocess_maniskill_observation( + observations: dict[str, np.ndarray], +) -> dict[str, torch.Tensor]: """Convert environment observation to LeRobot format observation. Args: observation: Dictionary of observation batches from a Gym vector environment. @@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper): new_action_space_shape = env.action_space.shape[-1] new_low = np.squeeze(env.action_space.low, axis=0) new_high = np.squeeze(env.action_space.high, axis=0) - self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,)) + self.action_space = gym.spaces.Box( + low=new_low, high=new_high, shape=(new_action_space_shape,) + ) def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None @@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper): class ManiSkillActionWrapper(gym.ActionWrapper): def __init__(self, env): super().__init__(env) - self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2))) + self.action_space = gym.spaces.Tuple( + spaces=(env.action_space, gym.spaces.Discrete(2)) + ) def action(self, action): action, telop = action @@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper): action_space_agent: gym.spaces.Box = env.action_space[0] action_space_agent.low = action_space_agent.low * multiply_factor action_space_agent.high = action_space_agent.high * multiply_factor - self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2))) + self.action_space = gym.spaces.Tuple( + spaces=(action_space_agent, gym.spaces.Discrete(2)) + ) def step(self, action): if isinstance(action, tuple): @@ -137,7 +145,9 @@ def make_maniskill( env = ManiSkillObservationWrapper(env, device=cfg.env.device) env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False) - env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) + env._max_episode_steps = env.max_episode_steps = ( + 50 # gym_utils.find_max_episode_steps_value(env) + ) env.unwrapped.metadata["render_fps"] = 20 env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) @@ -149,10 +159,11 @@ def make_maniskill( if __name__ == "__main__": import argparse import hydra - from omegaconf import OmegaConf parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml") + parser.add_argument( + "--config", type=str, default="lerobot/configs/env/maniskill_example.yaml" + ) args = parser.parse_args() # Initialize config diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 04e6c5b4..5dee312b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -73,7 +73,9 @@ def make_optimizer_and_scheduler(cfg, policy): }, ] optimizer = torch.optim.AdamW( - optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay + optimizer_params_dicts, + lr=cfg.training.lr, + weight_decay=cfg.training.weight_decay, ) lr_scheduler = None elif cfg.policy.name == "diffusion": @@ -100,14 +102,23 @@ def make_optimizer_and_scheduler(cfg, policy): optimizer = torch.optim.Adam( [ {"params": policy.actor.parameters(), "lr": policy.config.actor_lr}, - {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr}, - {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr}, + { + "params": policy.critic_ensemble.parameters(), + "lr": policy.config.critic_lr, + }, + { + "params": policy.temperature.parameters(), + "lr": policy.config.temperature_lr, + }, ] ) lr_scheduler = None elif cfg.policy.name == "vqbet": - from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler + from lerobot.common.policies.vqbet.modeling_vqbet import ( + VQBeTOptimizer, + VQBeTScheduler, + ) optimizer = VQBeTOptimizer(policy, cfg) lr_scheduler = VQBeTScheduler(optimizer, cfg) @@ -214,7 +225,9 @@ def train(cfg: TrainPipelineConfig): if cfg.resume: step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) - num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_learnable_params = sum( + p.numel() for p in policy.parameters() if p.requires_grad + ) num_total_params = sum(p.numel() for p in policy.parameters()) logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index e0e01a5d..6044b038 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -14,7 +14,6 @@ import logging import time from contextlib import nullcontext -from pathlib import Path from pprint import pformat import hydra @@ -28,14 +27,16 @@ from termcolor import colored from torch import optim from torch.autograd import profiler from torch.cuda.amp import GradScaler -from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split +from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler from tqdm import tqdm from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg -from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( + ClassifierConfig, +) from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.common.utils.utils import ( format_big_number, @@ -50,7 +51,11 @@ def get_model(cfg, logger): # noqa I001 classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) model = Classifier(classifier_config) if cfg.resume: - model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict()) + model.load_state_dict( + Classifier.from_pretrained( + str(logger.last_pretrained_model_dir) + ).state_dict() + ) return model @@ -62,7 +67,9 @@ def create_balanced_sampler(dataset, cfg): class_weights = 1.0 / counts.float() sample_weights = class_weights[labels] - return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) + return WeightedRandomSampler( + weights=sample_weights, num_samples=len(sample_weights), replacement=True + ) def support_amp(device: torch.device, cfg: DictConfig) -> bool: @@ -71,7 +78,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool: return cfg.training.use_amp and device.type in ("cuda", "cpu") -def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg): +def train_epoch( + model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg +): # Single epoch training loop with AMP support and progress tracking model.train() correct = 0 @@ -85,7 +94,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, labels = batch[cfg.training.label_key].float().to(device) # Forward pass with optional AMP - with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(): + with ( + torch.autocast(device_type=device.type) + if support_amp(device, cfg) + else nullcontext() + ): outputs = model(images) loss = criterion(outputs.logits, labels) @@ -130,7 +143,9 @@ def validate(model, val_loader, criterion, device, logger, cfg): with ( torch.no_grad(), - torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(), + torch.autocast(device_type=device.type) + if support_amp(device, cfg) + else nullcontext(), ): for batch in tqdm(val_loader, desc="Validation"): images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] @@ -143,7 +158,9 @@ def validate(model, val_loader, criterion, device, logger, cfg): ): outputs = model(images) inference_times.append( - next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time + next( + x for x in prof.key_averages() if x.key == "model_inference" + ).cpu_time ) else: outputs = model(images) @@ -161,16 +178,24 @@ def validate(model, val_loader, criterion, device, logger, cfg): # Log sample predictions for visualization if len(samples) < cfg.eval.num_samples_to_log: - for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))): + for i in range( + min(cfg.eval.num_samples_to_log - len(samples), len(images)) + ): if model.config.num_classes == 2: confidence = round(outputs.probabilities[i].item(), 3) else: - confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()] + confidence = [ + round(prob, 3) for prob in outputs.probabilities[i].tolist() + ] samples.append( { **{ - f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) - for img_idx, img_key in enumerate(cfg.training.image_keys) + f"image_{img_key}": wandb.Image( + images[img_idx][i].cpu() + ) + for img_idx, img_key in enumerate( + cfg.training.image_keys + ) }, "true_label": labels[i].item(), "predicted": predictions[i].item(), @@ -238,15 +263,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step): elif device.type == "mps": torch.mps.synchronize() - with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"): + with ( + profiler.profile(record_shapes=True) as prof, + profiler.record_function("model_inference"), + ): _ = model(x) inference_times.append( - next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time + next( + x for x in prof.key_averages() if x.key == "model_inference" + ).cpu_time ) inference_times = np.array(inference_times) - avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std() + avg, median, std = ( + inference_times.mean(), + np.median(inference_times), + inference_times.std(), + ) print( f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device" ) @@ -264,7 +298,11 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step): return avg, median, std -@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier") +@hydra.main( + version_base="1.2", + config_path="../configs/policy", + config_name="hilserl_classifier", +) def train(cfg: DictConfig) -> None: # Main training pipeline with support for resuming training logging.info(OmegaConf.to_yaml(cfg)) @@ -278,7 +316,9 @@ def train(cfg: DictConfig) -> None: # Setup dataset and dataloaders dataset = LeRobotDataset( - cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only + cfg.dataset_repo_id, + root=cfg.dataset_root, + local_files_only=cfg.local_files_only, ) logging.info(f"Dataset size: {len(dataset)}") @@ -314,7 +354,9 @@ def train(cfg: DictConfig) -> None: "You have set resume=True, but there is no model checkpoint in " f"{Logger.get_last_checkpoint_dir(out_dir)}" ) - checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") + checkpoint_cfg_path = str( + Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml" + ) logging.info( colored( "You have set resume=True, indicating that you wish to resume a run", @@ -327,7 +369,9 @@ def train(cfg: DictConfig) -> None: # Check for differences between the checkpoint configuration and provided configuration. # Hack to resolve the delta_timestamps ahead of time in order to properly diff. resolve_delta_timestamps(cfg) - diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) + diff = DeepDiff( + OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg) + ) # Ignore the `resume` and parameters. if "values_changed" in diff and "root['resume']" in diff["values_changed"]: del diff["values_changed"]["root['resume']"] @@ -346,7 +390,11 @@ def train(cfg: DictConfig) -> None: optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate) # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class - criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss() + criterion = ( + nn.BCEWithLogitsLoss() + if model.config.num_classes == 2 + else nn.CrossEntropyLoss() + ) grad_scaler = GradScaler(enabled=cfg.training.use_amp) # Log model parameters @@ -362,7 +410,17 @@ def train(cfg: DictConfig) -> None: for epoch in range(cfg.training.num_epochs): logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}") - train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg) + train_epoch( + model, + train_loader, + criterion, + optimizer, + grad_scaler, + device, + logger, + step, + cfg, + ) # Periodic validation if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0: diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py index 4f7b55cc..cfd05f62 100644 --- a/lerobot/scripts/train_sac.py +++ b/lerobot/scripts/train_sac.py @@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict import hydra import torch import torch.nn.functional as F -from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from torch import nn from tqdm import tqdm @@ -30,20 +29,17 @@ from tqdm import tqdm # TODO: Remove the import of maniskill from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.envs.factory import make_env, make_maniskill_env -from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation +from lerobot.common.envs.factory import make_maniskill_env +from lerobot.common.envs.utils import preprocess_maniskill_observation from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy -from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, - init_hydra_config, init_logging, set_global_seed, ) -from lerobot.scripts.eval import eval_policy def make_optimizers_and_scheduler(cfg, policy): @@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy): params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr ) # We wrap policy log temperature in list because this is a torch tensor and not a nn.Module - optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr) + optimizer_temperature = torch.optim.Adam( + params=[policy.log_alpha], lr=policy.config.critic_lr + ) lr_scheduler = None optimizers = { "actor": optimizer_actor, @@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) # Gather pixels - cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] + cropped_hwcn = images_hwcn[ + torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, : + ] # cropped_hwcn => (B, crop_h, crop_w, C) cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) @@ -198,8 +198,12 @@ class ReplayBuffer: """ # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from # a replay buffer than from a lerobot dataset. - replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys) - list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) + replay_buffer = cls( + capacity=len(lerobot_dataset), device=device, state_keys=state_keys + ) + list_transition = cls._lerobotdataset_to_transitions( + dataset=lerobot_dataset, state_keys=state_keys + ) # Fill the replay buffer with the lerobot dataset transitions for data in list_transition: replay_buffer.add( @@ -244,7 +248,9 @@ class ReplayBuffer: # If not provided, you can either raise an error or define a default: if state_keys is None: - raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.") + raise ValueError( + "You must provide a list of keys in `state_keys` that define your 'state'." + ) transitions: list[Transition] = [] num_frames = len(dataset) @@ -298,36 +304,40 @@ class ReplayBuffer: # -- Build batched states -- batch_state = {} for key in self.state_keys: - batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to( - self.device - ) + batch_state[key] = torch.cat( + [t["state"][key] for t in list_of_transitions], dim=0 + ).to(self.device) if key.startswith("observation.image") and self.use_drq: batch_state[key] = self.image_augmentation_function(batch_state[key]) # -- Build batched actions -- - batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device) - - # -- Build batched rewards -- - batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to( + batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to( self.device ) + # -- Build batched rewards -- + batch_rewards = torch.tensor( + [t["reward"] for t in list_of_transitions], dtype=torch.float32 + ).to(self.device) + # -- Build batched next states -- batch_next_state = {} for key in self.state_keys: - batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to( - self.device - ) + batch_next_state[key] = torch.cat( + [t["next_state"][key] for t in list_of_transitions], dim=0 + ).to(self.device) if key.startswith("observation.image") and self.use_drq: - batch_next_state[key] = self.image_augmentation_function(batch_next_state[key]) + batch_next_state[key] = self.image_augmentation_function( + batch_next_state[key] + ) # -- Build batched dones -- - batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( - self.device - ) - batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( - self.device - ) + batch_dones = torch.tensor( + [t["done"] for t in list_of_transitions], dtype=torch.float32 + ).to(self.device) + batch_dones = torch.tensor( + [t["done"] for t in list_of_transitions], dtype=torch.float32 + ).to(self.device) # Return a BatchTransition typed dict return BatchTransition( @@ -344,7 +354,13 @@ def concatenate_batch_transitions( ) -> BatchTransition: """NOTE: Be careful it change the left_batch_transitions in place""" left_batch_transitions["state"] = { - key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0) + key: torch.cat( + [ + left_batch_transitions["state"][key], + right_batch_transition["state"][key], + ], + dim=0, + ) for key in left_batch_transitions["state"] } left_batch_transitions["action"] = torch.cat( @@ -355,7 +371,11 @@ def concatenate_batch_transitions( ) left_batch_transitions["next_state"] = { key: torch.cat( - [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0 + [ + left_batch_transitions["next_state"][key], + right_batch_transition["next_state"][key], + ], + dim=0, ) for key in left_batch_transitions["next_state"] } @@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # Hack: But if we do online traning, we do not need dataset_stats dataset_stats=None, - pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, + pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) + if cfg.resume + else None, device=device, ) assert isinstance(policy, nn.Module) @@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # TODO: Handle resume - num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_learnable_params = sum( + p.numel() for p in policy.parameters() if p.requires_grad + ) num_total_params = sum(p.numel() for p in policy.parameters()) log_output_dir(out_dir) @@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No obs = {key: obs[key].to(device, non_blocking=True) for key in obs} replay_buffer = ReplayBuffer( - capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys() + capacity=cfg.training.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_shapes.keys(), ) batch_size = cfg.training.batch_size @@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No if interaction_step >= cfg.training.online_step_before_learning: action = policy.select_action(batch=obs) - next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy()) + next_obs, reward, done, truncated, info = online_env.step( + action.cpu().numpy() + ) else: action = online_env.action_space.sample() next_obs, reward, done, truncated, info = online_env.step(action) # HACK - action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True) + action = torch.tensor(action, dtype=torch.float32).to( + device, non_blocking=True + ) # HACK: For maniskill # next_obs = preprocess_observation(next_obs) @@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Because we are using a single environment # we can safely assume that the episode is done if done[0] or truncated[0]: - logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}") - logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step) + logging.info( + f"Global step {interaction_step}: Episode reward: {sum_reward_episode}" + ) + logger.log_dict( + {"Sum episode reward": sum_reward_episode}, interaction_step + ) sum_reward_episode = 0 # HACK: This is for maniskill logging.info( f"global step {interaction_step}: episode success: {info['success'].float().item()} \n" ) - logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step) + logger.log_dict( + {"Episode success": info["success"].float().item()}, interaction_step + ) replay_buffer.add( state=obs, @@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No training_infos["loss_actor"] = loss_actor.item() - loss_temperature = policy.compute_loss_temperature(observations=observations) + loss_temperature = policy.compute_loss_temperature( + observations=observations + ) optimizers["temperature"].zero_grad() loss_temperature.backward() optimizers["temperature"].step() @@ -573,7 +611,9 @@ def train_cli(cfg: dict): ) -def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"): +def train_notebook( + out_dir=None, job_name=None, config_name="default", config_path="../configs" +): from hydra import compose, initialize hydra.core.global_hydra.GlobalHydra.instance().clear() diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index cdfea6b8..340e6516 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.ndim == 3 c, h, w = chw_float32_torch.shape - assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" - hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() + assert ( + c < h and c < w + ), f"expect channel first images, but instead {chw_float32_torch.shape}" + hwc_uint8_numpy = ( + (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() + ) return hwc_uint8_numpy diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 0fc21a8f..51dbf4c2 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -81,7 +81,11 @@ def run_server( static_folder: Path, template_folder: Path, ): - app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) + app = Flask( + __name__, + static_folder=static_folder.resolve(), + template_folder=template_folder.resolve(), + ) app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache @app.route("/") @@ -138,8 +142,12 @@ def run_server( ) ) - @app.route("///episode_") - def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes): + @app.route( + "///episode_" + ) + def show_episode( + dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes + ): repo_id = f"{dataset_namespace}/{dataset_name}" try: if dataset is None: @@ -171,15 +179,21 @@ def run_server( } if isinstance(dataset, LeRobotDataset): video_paths = [ - dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys + dataset.meta.get_video_file_path(episode_id, key) + for key in dataset.meta.video_keys ] videos_info = [ - {"url": url_for("static", filename=video_path), "filename": video_path.parent.name} + { + "url": url_for("static", filename=video_path), + "filename": video_path.parent.name, + } for video_path in video_paths ] tasks = dataset.meta.episodes[episode_id]["tasks"] else: - video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"] + video_keys = [ + key for key, ft in dataset.features.items() if ft["dtype"] == "video" + ] videos_info = [ { "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/" @@ -198,16 +212,24 @@ def run_server( ) response.raise_for_status() # Split into lines and parse each line as JSON - tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()] + tasks_jsonl = [ + json.loads(line) for line in response.text.splitlines() if line.strip() + ] - filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id] + filtered_tasks_jsonl = [ + row for row in tasks_jsonl if row["episode_index"] == episode_id + ] tasks = filtered_tasks_jsonl[0]["tasks"] videos_info[0]["language_instruction"] = tasks if episodes is None: episodes = list( - range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) + range( + dataset.num_episodes + if isinstance(dataset, LeRobotDataset) + else dataset.total_episodes + ) ) return render_template( @@ -255,7 +277,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) else dataset.features[column_name].shape[0] ) - if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: + if ( + "names" in dataset.features[column_name] + and dataset.features[column_name]["names"] + ): column_names = dataset.features[column_name]["names"] while not isinstance(column_names, list): column_names = list(column_names.values())[0] @@ -278,8 +303,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) else: repo_id = dataset.repo_id - url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( - episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index + url = ( + f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + + dataset.data_path.format( + episode_chunk=int(episode_index) // dataset.chunks_size, + episode_index=episode_index, + ) ) df = pd.read_parquet(url) data = df[selected_columns] # Select specific columns @@ -312,7 +341,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str] ] -def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: +def get_episode_language_instruction( + dataset: LeRobotDataset, ep_index: int +) -> list[str]: # check if the dataset has language instructions if "language_instruction" not in dataset.features: return None @@ -323,7 +354,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored # with the tf.tensor appearing in the string - return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") + return language_instruction.removeprefix("tf.Tensor(b'").removesuffix( + "', shape=(), dtype=string)" + ) def get_dataset_info(repo_id: str) -> IterableNamespace: @@ -358,7 +391,9 @@ def visualize_dataset_html( if force_override: shutil.rmtree(output_dir) else: - logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") + logging.info( + f"Output directory already exists. Loading from it: '{output_dir}'" + ) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/tests/artifacts/datasets/save_dataset_to_safetensors.py b/tests/artifacts/datasets/save_dataset_to_safetensors.py index 74d42a3d..3159e8c8 100644 --- a/tests/artifacts/datasets/save_dataset_to_safetensors.py +++ b/tests/artifacts/datasets/save_dataset_to_safetensors.py @@ -52,7 +52,13 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + i = int( + ( + dataset.episode_data_index["to"][0].item() + - dataset.episode_data_index["from"][0].item() + ) + / 2 + ) save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") diff --git a/tests/cameras/mock_pyrealsense2.py b/tests/cameras/mock_pyrealsense2.py index c477eb06..38baf5ba 100644 --- a/tests/cameras/mock_pyrealsense2.py +++ b/tests/cameras/mock_pyrealsense2.py @@ -30,7 +30,9 @@ class config: # noqa: N801 def enable_device(self, device_id: str): self.device_enabled = device_id - def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None): + def enable_stream( + self, stream_type: stream, width=None, height=None, color_format=None, fps=None + ): self.stream_type = stream_type # Overwrite default values when possible self.width = 848 if width is None else width diff --git a/tests/cameras/test_cameras.py b/tests/cameras/test_cameras.py index 868358ec..971ac4e0 100644 --- a/tests/cameras/test_cameras.py +++ b/tests/cameras/test_cameras.py @@ -37,7 +37,10 @@ pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]' import numpy as np import pytest -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, +) from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera # Maximum absolute difference between two consecutive images recorded by a camera. @@ -112,7 +115,11 @@ def test_camera(request, camera_type, mock): ) # TODO(rcadene): properly set `rtol` np.testing.assert_allclose( - color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg + color_image, + async_color_image, + rtol=1e-5, + atol=MAX_PIXEL_DIFFERENCE, + err_msg=error_msg, ) # Test disconnecting @@ -131,7 +138,11 @@ def test_camera(request, camera_type, mock): assert camera.color_mode == "bgr" bgr_color_image = camera.read() np.testing.assert_allclose( - color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg + color_image, + bgr_color_image[:, :, [2, 1, 0]], + rtol=1e-5, + atol=MAX_PIXEL_DIFFERENCE, + err_msg=error_msg, ) del camera @@ -166,7 +177,11 @@ def test_camera(request, camera_type, mock): rot_color_image = camera.read() np.testing.assert_allclose( - rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg + rot_color_image, + manual_rot_img, + rtol=1e-5, + atol=MAX_PIXEL_DIFFERENCE, + err_msg=error_msg, ) del camera @@ -200,7 +215,9 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock): if camera_type == "opencv": from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras elif camera_type == "intelrealsense": - from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras + from lerobot.common.robot_devices.cameras.intelrealsense import ( + save_images_from_cameras, + ) # Small `record_time_s` to speedup unit tests save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock) diff --git a/tests/conftest.py b/tests/conftest.py index cc35768e..fa793606 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,7 +91,12 @@ def patch_builtins_input(monkeypatch): def pytest_addoption(parser): - parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility") + parser.addoption( + "--seed", + action="store", + default="42", + help="Set random seed for reproducibility", + ) @pytest.fixture(autouse=True) diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 352aba99..44f38b2e 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -364,10 +364,16 @@ def test_save_each_transform(img_tensor_factory, tmp_path): for transform in transforms: transform_dir = tmp_path / transform assert transform_dir.exists(), f"{transform} directory was not created." - assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory." + assert any( + transform_dir.iterdir() + ), f"No transformed images found in {transform} directory." # Check for specific files within each transform directory - expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"] + expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [ + "min.png", + "max.png", + "mean.png", + ] for file_name in expected_files: assert (transform_dir / file_name).exists(), ( f"{file_name} was not found in {transform} directory." diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index 802fe0d3..6655b415 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -187,7 +187,9 @@ def test_save_image_torch(tmp_path, img_tensor_factory): writer.wait_until_done() assert fpath.exists() saved_image = np.array(Image.open(fpath)) - expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype( + np.uint8 + ) assert np.array_equal(expected_image, saved_image) finally: writer.stop() @@ -202,7 +204,9 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory): writer.wait_until_done() assert fpath.exists() saved_image = np.array(Image.open(fpath)) - expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype( + np.uint8 + ) assert np.array_equal(expected_image, saved_image) finally: writer.stop() @@ -292,7 +296,9 @@ def test_wait_until_done(tmp_path, img_array_factory): writer = AsyncImageWriter(num_processes=0, num_threads=4) try: num_images = 100 - image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)] + image_arrays = [ + img_array_factory(height=500, width=500) for _ in range(num_images) + ] fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)] for image_array, fpath in zip(image_arrays, fpaths, strict=True): fpath.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/datasets/test_online_buffer.py b/tests/datasets/test_online_buffer.py index 339f6848..0285be1b 100644 --- a/tests/datasets/test_online_buffer.py +++ b/tests/datasets/test_online_buffer.py @@ -44,13 +44,23 @@ def make_new_buffer( return buffer, write_dir -def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]: +def make_spoof_data_frames( + n_episodes: int, n_frames_per_episode: int +) -> dict[str, np.ndarray]: new_data = { - data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape), + data_key: np.arange( + n_frames_per_episode * n_episodes * np.prod(data_shape) + ).reshape(-1, *data_shape), OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes), - OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode), - OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes), - OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes), + OnlineBuffer.EPISODE_INDEX_KEY: np.repeat( + np.arange(n_episodes), n_frames_per_episode + ), + OnlineBuffer.FRAME_INDEX_KEY: np.tile( + np.arange(n_frames_per_episode), n_episodes + ), + OnlineBuffer.TIMESTAMP_KEY: np.tile( + np.arange(n_frames_per_episode) / fps, n_episodes + ), } return new_data @@ -219,47 +229,72 @@ def test_compute_sampler_weights_trivial( online_dataset_size: int, online_sampling_ratio: float, ): - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size) + offline_dataset = lerobot_dataset_factory( + tmp_path, total_episodes=1, total_frames=offline_dataset_size + ) online_dataset, _ = make_new_buffer() if online_dataset_size > 0: online_dataset.add_data( - make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2) + make_spoof_data_frames( + n_episodes=2, n_frames_per_episode=online_dataset_size // 2 + ) ) weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio + offline_dataset, + online_dataset=online_dataset, + online_sampling_ratio=online_sampling_ratio, ) if offline_dataset_size == 0 or online_dataset_size == 0: expected_weights = torch.ones(offline_dataset_size + online_dataset_size) elif online_sampling_ratio == 0: - expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]) + expected_weights = torch.cat( + [torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)] + ) elif online_sampling_ratio == 1: - expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) + expected_weights = torch.cat( + [torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)] + ) expected_weights /= expected_weights.sum() torch.testing.assert_close(weights, expected_weights) def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) + offline_dataset = lerobot_dataset_factory( + tmp_path, total_episodes=1, total_frames=4 + ) online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) + online_dataset.add_data( + make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2) + ) online_sampling_ratio = 0.8 weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio + offline_dataset, + online_dataset=online_dataset, + online_sampling_ratio=online_sampling_ratio, ) torch.testing.assert_close( weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) ) -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path): +def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n( + lerobot_dataset_factory, tmp_path +): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) + offline_dataset = lerobot_dataset_factory( + tmp_path, total_episodes=1, total_frames=4 + ) online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) + online_dataset.add_data( + make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2) + ) weights = compute_sampler_weights( - offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1 + offline_dataset, + online_dataset=online_dataset, + online_sampling_ratio=0.8, + online_drop_n_last_frames=1, ) torch.testing.assert_close( weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]) @@ -268,9 +303,13 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path): """Note: test copied from test_sampler.""" - offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2) + offline_dataset = lerobot_dataset_factory( + tmp_path, total_episodes=1, total_frames=2 + ) online_dataset, _ = make_new_buffer() - online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) + online_dataset.add_data( + make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2) + ) weights = compute_sampler_weights( offline_dataset, diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index ee143f37..2b329a16 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -15,7 +15,9 @@ # limitations under the License. from datasets import Dataset -from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index +from lerobot.common.datasets.push_dataset_to_hub.utils import ( + calculate_episode_data_index, +) from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import ( hf_transform_to_torch, diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 5e5c762c..5e765382 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -20,17 +20,39 @@ DUMMY_MOTOR_FEATURES = { "action": { "dtype": "float32", "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], + "names": [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ], }, "state": { "dtype": "float32", "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], + "names": [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ], }, } DUMMY_CAMERA_FEATURES = { - "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, - "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, + "laptop": { + "shape": (480, 640, 3), + "names": ["height", "width", "channels"], + "info": None, + }, + "phone": { + "shape": (480, 640, 3), + "names": ["height", "width", "channels"], + "info": None, + }, } DEFAULT_FPS = 30 DUMMY_VIDEO_INFO = { diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 531977da..b7b63614 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -23,7 +23,11 @@ import PIL.Image import pytest import torch -from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata +from lerobot.common.datasets.lerobot_dataset import ( + CODEBASE_VERSION, + LeRobotDataset, + LeRobotDatasetMetadata, +) from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_FEATURES, @@ -54,7 +58,9 @@ def get_task_index(task_dicts: dict, task: str) -> int: @pytest.fixture(scope="session") def img_tensor_factory(): - def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor: + def _create_img_tensor( + height=100, width=100, channels=3, dtype=torch.float32 + ) -> torch.Tensor: return torch.rand((channels, height, width), dtype=dtype) return _create_img_tensor @@ -62,10 +68,14 @@ def img_tensor_factory(): @pytest.fixture(scope="session") def img_array_factory(): - def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray: + def _create_img_array( + height=100, width=100, channels=3, dtype=np.uint8 + ) -> np.ndarray: if np.issubdtype(dtype, np.unsignedinteger): # Int array in [0, 255] range - img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) + img_array = np.random.randint( + 0, 256, size=(height, width, channels), dtype=dtype + ) elif np.issubdtype(dtype, np.floating): # Float array in [0, 1] range img_array = np.random.rand(height, width, channels).astype(dtype) @@ -94,10 +104,13 @@ def features_factory(): ) -> dict: if use_videos: camera_ft = { - key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items() + key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} + for key, ft in camera_features.items() } else: - camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()} + camera_ft = { + key: {"dtype": "image", **ft} for key, ft in camera_features.items() + } return { **motor_features, **camera_ft, @@ -215,7 +228,9 @@ def episodes_factory(tasks_factory): if total_episodes <= 0 or total_frames <= 0: raise ValueError("num_episodes and total_length must be positive integers.") if total_frames < total_episodes: - raise ValueError("total_length must be greater than or equal to num_episodes.") + raise ValueError( + "total_length must be greater than or equal to num_episodes." + ) if not tasks: min_tasks = 2 if multi_task else 1 @@ -223,10 +238,14 @@ def episodes_factory(tasks_factory): tasks = tasks_factory(total_tasks) if total_episodes < len(tasks) and not multi_task: - raise ValueError("The number of tasks should be less than the number of episodes.") + raise ValueError( + "The number of tasks should be less than the number of episodes." + ) # Generate random lengths that sum up to total_length - lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() + lengths = np.random.multinomial( + total_frames, [1 / total_episodes] * total_episodes + ).tolist() tasks_list = [task_dict["task"] for task_dict in tasks.values()] num_tasks_available = len(tasks_list) @@ -234,9 +253,13 @@ def episodes_factory(tasks_factory): episodes = {} remaining_tasks = tasks_list.copy() for ep_idx in range(total_episodes): - num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 + num_tasks_in_episode = ( + random.randint(1, min(3, num_tasks_available)) if multi_task else 1 + ) tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list - episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))) + episode_tasks = random.sample( + tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)) + ) if remaining_tasks: for task in episode_tasks: remaining_tasks.remove(task) @@ -253,7 +276,9 @@ def episodes_factory(tasks_factory): @pytest.fixture(scope="session") -def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): +def hf_dataset_factory( + features_factory, tasks_factory, episodes_factory, img_array_factory +): def _create_hf_dataset( features: dict | None = None, tasks: list[dict] | None = None, @@ -275,10 +300,15 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) episode_index_col = np.concatenate( - (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int)) + ( + episode_index_col, + np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int), + ) ) ep_task_index = get_task_index(tasks, ep_dict["tasks"][0]) - task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))) + task_index = np.concatenate( + (task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)) + ) index_col = np.arange(len(episode_index_col)) @@ -290,7 +320,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar for _ in range(len(index_col)) ] elif ft["shape"][0] > 1 and ft["dtype"] != "video": - robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"]) + robot_cols[key] = np.random.random( + (len(index_col), ft["shape"][0]) + ).astype(ft["dtype"]) hf_features = get_hf_features_from_features(features) dataset = datasets.Dataset.from_dict( @@ -340,7 +372,9 @@ def lerobot_dataset_metadata_factory( tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episodes: episodes = episodes_factory( - total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks + total_episodes=info["total_episodes"], + total_frames=info["total_frames"], + tasks=tasks, ) mock_snapshot_download = mock_snapshot_download_factory( @@ -392,7 +426,9 @@ def lerobot_dataset_factory( ) -> LeRobotDataset: if not info: info = info_factory( - total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks + total_episodes=total_episodes, + total_frames=total_frames, + total_tasks=total_tasks, ) if not stats: stats = stats_factory(features=info["features"]) @@ -408,7 +444,9 @@ def lerobot_dataset_factory( multi_task=multi_task, ) if not hf_dataset: - hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"]) + hf_dataset = hf_dataset_factory( + tasks=tasks, episodes=episode_dicts, fps=info["fps"] + ) mock_snapshot_download = mock_snapshot_download_factory( info=info, diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 678d1f38..d869586f 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -102,7 +102,10 @@ def episode_path(episodes_factory): @pytest.fixture(scope="session") def single_episode_parquet_path(hf_dataset_factory, info_factory): def _create_single_episode_parquet( - dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None + dir: Path, + ep_idx: int = 0, + hf_dataset: datasets.Dataset | None = None, + info: dict | None = None, ) -> Path: if not info: info = info_factory() diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index aa2768e4..0bf4cc69 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -67,15 +67,21 @@ def mock_snapshot_download_factory( tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episodes: episodes = episodes_factory( - total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks + total_episodes=info["total_episodes"], + total_frames=info["total_frames"], + tasks=tasks, ) if not hf_dataset: - hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) + hf_dataset = hf_dataset_factory( + tasks=tasks, episodes=episodes, fps=info["fps"] + ) def _extract_episode_index_from_path(fpath: str) -> int: path = Path(fpath) if path.suffix == ".parquet" and path.stem.startswith("episode_"): - episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0 + episode_index = int( + path.stem[len("episode_") :] + ) # 'episode_000000' -> 0 return episode_index else: return None @@ -100,12 +106,16 @@ def mock_snapshot_download_factory( for episode_dict in episodes.values(): ep_idx = episode_dict["episode_index"] ep_chunk = ep_idx // info["chunks_size"] - data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) + data_path = info["data_path"].format( + episode_chunk=ep_chunk, episode_index=ep_idx + ) data_files.append(data_path) all_files.extend(data_files) allowed_files = filter_repo_objects( - all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns + all_files, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, ) # Create allowed files @@ -113,7 +123,9 @@ def mock_snapshot_download_factory( if rel_path.startswith("data/"): episode_index = _extract_episode_index_from_path(rel_path) if episode_index is not None: - _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info) + _ = single_episode_parquet_path( + local_dir, episode_index, hf_dataset, info + ) if rel_path == INFO_PATH: _ = info_path(local_dir, info) elif rel_path == STATS_PATH: diff --git a/tests/motors/mock_dynamixel_sdk.py b/tests/motors/mock_dynamixel_sdk.py index ee399f96..387ff528 100644 --- a/tests/motors/mock_dynamixel_sdk.py +++ b/tests/motors/mock_dynamixel_sdk.py @@ -80,7 +80,9 @@ class GroupSyncRead: def addParam(self, motor_index): # noqa: N802 # Initialize motor default values if motor_index not in self.packet_handler.data: - self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) + self.packet_handler.data[motor_index] = get_default_motor_values( + motor_index + ) def txRxPacket(self): # noqa: N802 return COMM_SUCCESS diff --git a/tests/motors/mock_scservo_sdk.py b/tests/motors/mock_scservo_sdk.py index 37f6d0d5..be6be756 100644 --- a/tests/motors/mock_scservo_sdk.py +++ b/tests/motors/mock_scservo_sdk.py @@ -91,7 +91,9 @@ class GroupSyncRead: def addParam(self, motor_index): # noqa: N802 # Initialize motor default values if motor_index not in self.packet_handler.data: - self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) + self.packet_handler.data[motor_index] = get_default_motor_values( + motor_index + ) def txRxPacket(self): # noqa: N802 return COMM_SUCCESS diff --git a/tests/motors/test_motors.py b/tests/motors/test_motors.py index da7a5c54..c8013953 100644 --- a/tests/motors/test_motors.py +++ b/tests/motors/test_motors.py @@ -43,7 +43,10 @@ import time import numpy as np import pytest -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, +) from lerobot.scripts.find_motors_bus_port import find_port from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor @@ -76,7 +79,9 @@ def test_configure_motors_all_ids_1(request, motor_type, mock): else: raise ValueError(motor_type) - input("Are you sure you want to re-configure the motors? Press enter to continue...") + input( + "Are you sure you want to re-configure the motors? Press enter to continue..." + ) # This test expect the configuration was already correct. motors_bus = make_motors_bus(motor_type, mock=mock) motors_bus.connect() diff --git a/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py index 55e6e381..84b96b6d 100644 --- a/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py +++ b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py @@ -25,7 +25,10 @@ from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall from torchvision.datasets import CIFAR10 from torchvision.transforms import ToTensor -from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig +from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ClassifierConfig, +) BATCH_SIZE = 1000 LR = 0.1 @@ -43,7 +46,9 @@ def train_evaluate_multiclass_classifier(): logging.info( f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}" ) - multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10) + multiclass_config = ClassifierConfig( + model_name="microsoft/resnet-18", device=DEVICE, num_classes=10 + ) multiclass_classifier = Classifier(multiclass_config) trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor()) @@ -114,10 +119,18 @@ def train_evaluate_multiclass_classifier(): test_probs = torch.stack(test_probs) accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes) - precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes) - recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes) - f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes) - auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted") + precision = Precision( + task="multiclass", average="weighted", num_classes=multiclass_num_classes + ) + recall = Recall( + task="multiclass", average="weighted", num_classes=multiclass_num_classes + ) + f1 = F1Score( + task="multiclass", average="weighted", num_classes=multiclass_num_classes + ) + auroc = AUROC( + task="multiclass", num_classes=multiclass_num_classes, average="weighted" + ) # Calculate metrics acc = accuracy(test_predictions, test_labels) @@ -146,18 +159,28 @@ def train_evaluate_binary_classifier(): new_label = float(1.0) if label == target_class else float(0.0) new_targets.append(new_label) - dataset.targets = new_targets # Replace the original labels with the binary ones + dataset.targets = ( + new_targets # Replace the original labels with the binary ones + ) return dataset - binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor()) - binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor()) + binary_train_dataset = CIFAR10( + root="data", train=True, download=True, transform=ToTensor() + ) + binary_test_dataset = CIFAR10( + root="data", train=False, download=True, transform=ToTensor() + ) # Apply one-vs-rest labeling binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class) binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class) - binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True) - binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False) + binary_trainloader = DataLoader( + binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True + ) + binary_testloader = DataLoader( + binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False + ) binary_epoch = 1 diff --git a/tests/policies/hilserl/classifier/test_modelling_classifier.py b/tests/policies/hilserl/classifier/test_modelling_classifier.py index a3db4211..e8223a52 100644 --- a/tests/policies/hilserl/classifier/test_modelling_classifier.py +++ b/tests/policies/hilserl/classifier/test_modelling_classifier.py @@ -9,7 +9,9 @@ from tests.utils import require_package def test_classifier_output(): output = ClassifierOutput( - logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None + logits=torch.tensor([1, 2, 3]), + probabilities=torch.tensor([0.1, 0.2, 0.3]), + hidden_states=None, ) assert ( @@ -20,7 +22,9 @@ def test_classifier_output(): @require_package("transformers") def test_binary_classifier_with_default_params(): - from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ) config = ClassifierConfig() classifier = Classifier(config) @@ -41,7 +45,9 @@ def test_binary_classifier_with_default_params(): @require_package("transformers") def test_multiclass_classifier(): - from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ) num_classes = 5 config = ClassifierConfig(num_classes=num_classes) @@ -63,7 +69,9 @@ def test_multiclass_classifier(): @require_package("transformers") def test_default_device(): - from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ) config = ClassifierConfig() assert config.device == "cpu" @@ -75,7 +83,9 @@ def test_default_device(): @require_package("transformers") def test_explicit_device_setup(): - from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ) config = ClassifierConfig(device="meta") assert config.device == "meta" diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 197aa732..b3df477a 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -172,7 +172,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # Test updating the policy (and test that it does not mutate the batch) batch_ = deepcopy(batch) policy.forward(batch) - assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass." + assert set(batch) == set( + batch_ + ), "Batch keys are not the same after a forward pass." assert all( torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k] for k in batch @@ -186,7 +188,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): observation = preprocess_observation(observation) # send observation to device/gpu - observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} + observation = { + key: observation[key].to(DEVICE, non_blocking=True) for key in observation + } # get the next action for the environment (also check that the observation batch is not modified) observation_ = deepcopy(observation) @@ -452,7 +456,9 @@ def test_act_temporal_ensembler(): batch_size = batch_seq.shape[0] # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length` # dimension of `batch_seq`. - weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1) + weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze( + -1 + ) # Simulate stepping through a rollout and computing a batch of actions with model on each step. for i in range(episode_length): @@ -475,7 +481,8 @@ def test_act_temporal_ensembler(): episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :] seq_slice = batch_seq[:, episode_step_indices, chunk_indices] offline_avg = ( - einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum() + einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") + / weights[: i + 1].sum() ) # Sanity check. The average should be between the extrema. assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg) diff --git a/tests/robots/test_control_robot.py b/tests/robots/test_control_robot.py index 61d1caad..059bb79f 100644 --- a/tests/robots/test_control_robot.py +++ b/tests/robots/test_control_robot.py @@ -335,8 +335,12 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock) ) dataset = record(robot, rec_cfg) - assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False" - assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" + assert not mock_events[ + "rerecord_episode" + ], "`rerecord_episode` wasn't properly reset to False" + assert not mock_events[ + "exit_early" + ], "`exit_early` wasn't properly reset to False" assert len(dataset) == 1, "`dataset` should contain only 1 frame" @@ -391,7 +395,8 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): @pytest.mark.parametrize( - "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)] + "robot_type, mock, num_image_writer_processes", + [("koch", True, 0), ("koch", True, 1)], ) @require_robot def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes): diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py index 71343eba..4616c747 100644 --- a/tests/robots/test_robots.py +++ b/tests/robots/test_robots.py @@ -105,7 +105,9 @@ def test_robot(tmp_path, request, robot_type, mock): assert "observation.state" in observation assert isinstance(observation["observation.state"], torch.Tensor) assert observation["observation.state"].ndim == 1 - dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms) + dim_state = sum( + len(robot.follower_arms[name].motors) for name in robot.follower_arms + ) assert observation["observation.state"].shape[0] == dim_state # Cameras for name in robot.cameras: @@ -116,7 +118,9 @@ def test_robot(tmp_path, request, robot_type, mock): assert "action" in action assert isinstance(action["action"], torch.Tensor) assert action["action"].ndim == 1 - dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms) + dim_action = sum( + len(robot.follower_arms[name].motors) for name in robot.follower_arms + ) assert action["action"].shape[0] == dim_action # TODO(rcadene): test if observation and action data are returned as expected diff --git a/tests/test_train_hilserl_classifier.py b/tests/test_train_hilserl_classifier.py index 8c1ad453..bc7a18bc 100644 --- a/tests/test_train_hilserl_classifier.py +++ b/tests/test_train_hilserl_classifier.py @@ -9,7 +9,9 @@ from hydra import compose, initialize_config_dir from torch import nn from torch.utils.data import Dataset -from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( + ClassifierConfig, +) from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.scripts.train_hilserl_classifier import ( create_balanced_sampler, @@ -34,7 +36,9 @@ class MockDataset(Dataset): def make_dummy_model(): model_config = ClassifierConfig( - num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1 + num_classes=2, + model_name="hf-tiny-model-private/tiny-random-ResNetModel", + num_cameras=1, ) model = Classifier(config=model_config) return model @@ -65,7 +69,9 @@ def test_create_balanced_sampler(): labels = [item["label"] for item in data] class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32) class_weights = 1.0 / class_counts - expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32) + expected_weights = torch.tensor( + [class_weights[label] for label in labels], dtype=torch.float32 + ) # Test that the weights are correct assert torch.allclose(weights, expected_weights) @@ -149,7 +155,9 @@ def test_validate(): def test_train_epoch_multiple_cameras(): model_config = ClassifierConfig( - num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2 + num_classes=2, + model_name="hf-tiny-model-private/tiny-random-ResNetModel", + num_cameras=2, ) model = Classifier(config=model_config) @@ -216,10 +224,16 @@ def test_resume_function( ): # Initialize Hydra test_file_dir = os.path.dirname(os.path.abspath(__file__)) - config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")) - assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}" + config_dir = os.path.abspath( + os.path.join(test_file_dir, "..", "lerobot", "configs", "policy") + ) + assert os.path.exists( + config_dir + ), f"Config directory does not exist at {config_dir}" - with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"): + with initialize_config_dir( + config_dir=config_dir, job_name="test_app", version_base="1.2" + ): cfg = compose( config_name="hilserl_classifier", overrides=[ @@ -244,7 +258,9 @@ def test_resume_function( mock_init_hydra_config.return_value = cfg # Mock dataset - dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]) + dataset = MockDataset( + [{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)] + ) mock_dataset.return_value = dataset # Mock checkpoint handling diff --git a/tests/utils.py b/tests/utils.py index c49b5b9f..ca4d89bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -47,7 +47,9 @@ for motor_type in available_motors: OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)) -DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081") +DYNAMIXEL_PORT = os.environ.get( + "LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081" +) DYNAMIXEL_MOTORS = { "shoulder_pan": [1, "xl430-w250"], "shoulder_lift": [2, "xl430-w250"], @@ -57,7 +59,9 @@ DYNAMIXEL_MOTORS = { "gripper": [6, "xl330-m288"], } -FEETECH_PORT = os.environ.get("LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971") +FEETECH_PORT = os.environ.get( + "LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971" +) FEETECH_MOTORS = { "shoulder_pan": [1, "sts3215"], "shoulder_lift": [2, "sts3215"], @@ -156,9 +160,13 @@ def require_package_arg(func): if "required_packages" in arg_names: # Get the index of 'required_packages' and retrieve the value from args index = arg_names.index("required_packages") - required_packages = args[index] if len(args) > index else kwargs.get("required_packages") + required_packages = ( + args[index] if len(args) > index else kwargs.get("required_packages") + ) else: - raise ValueError("Function does not have 'required_packages' as an argument.") + raise ValueError( + "Function does not have 'required_packages' as an argument." + ) if required_packages is None: return func(*args, **kwargs) @@ -215,11 +223,17 @@ def require_robot(func): mock = kwargs.get("mock") if robot_type is None: - raise ValueError("The 'robot_type' must be an argument of the test function.") + raise ValueError( + "The 'robot_type' must be an argument of the test function." + ) if request is None: - raise ValueError("The 'request' fixture must be an argument of the test function.") + raise ValueError( + "The 'request' fixture must be an argument of the test function." + ) if mock is None: - raise ValueError("The 'mock' variable must be an argument of the test function.") + raise ValueError( + "The 'mock' variable must be an argument of the test function." + ) # Run test with a real robot. Skip test if robot connection fails. if not mock and not request.getfixturevalue("is_robot_available"): @@ -239,11 +253,17 @@ def require_camera(func): mock = kwargs.get("mock") if request is None: - raise ValueError("The 'request' fixture must be an argument of the test function.") + raise ValueError( + "The 'request' fixture must be an argument of the test function." + ) if camera_type is None: - raise ValueError("The 'camera_type' must be an argument of the test function.") + raise ValueError( + "The 'camera_type' must be an argument of the test function." + ) if mock is None: - raise ValueError("The 'mock' variable must be an argument of the test function.") + raise ValueError( + "The 'mock' variable must be an argument of the test function." + ) if not mock and not request.getfixturevalue("is_camera_available"): pytest.skip(f"A {camera_type} camera is not available.") @@ -262,11 +282,17 @@ def require_motor(func): mock = kwargs.get("mock") if request is None: - raise ValueError("The 'request' fixture must be an argument of the test function.") + raise ValueError( + "The 'request' fixture must be an argument of the test function." + ) if motor_type is None: - raise ValueError("The 'motor_type' must be an argument of the test function.") + raise ValueError( + "The 'motor_type' must be an argument of the test function." + ) if mock is None: - raise ValueError("The 'mock' variable must be an argument of the test function.") + raise ValueError( + "The 'mock' variable must be an argument of the test function." + ) if not mock and not request.getfixturevalue("is_motor_available"): pytest.skip(f"A {motor_type} motor is not available.") @@ -285,7 +311,14 @@ def mock_calibration_dir(calibration_dir): "start_pos": [1442, 843, 2166, 2849, 1988, 1835], "end_pos": [2440, 1869, -1106, -1848, -926, 3235], "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"], - "motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], + "motor_names": [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ], } Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True) with open(calibration_dir / "main_follower.json", "w") as f: