From 0ea27704f633c433f5b694153864f09aac4ca960 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Mar 2025 13:41:27 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- benchmarks/video/run_video_benchmark.py | 59 ++--- examples/1_load_lerobot_dataset.py | 9 +- examples/3_train_policy.py | 25 +-- examples/advanced/1_add_image_transforms.py | 8 +- .../advanced/2_calculate_validation_loss.py | 4 +- lerobot/__init__.py | 14 +- lerobot/common/datasets/compute_stats.py | 35 +-- lerobot/common/datasets/factory.py | 12 +- lerobot/common/datasets/image_writer.py | 20 +- lerobot/common/datasets/lerobot_dataset.py | 170 ++++----------- lerobot/common/datasets/online_buffer.py | 37 +--- .../datasets/push_dataset_to_hub/utils.py | 12 +- lerobot/common/datasets/transforms.py | 24 +- lerobot/common/datasets/utils.py | 94 ++------ .../datasets/v2/convert_dataset_v1_to_v2.py | 106 +++------ .../v21/_remove_language_instruction.py | 16 +- .../v21/convert_dataset_v20_to_v21.py | 4 +- lerobot/common/datasets/v21/convert_stats.py | 11 +- lerobot/common/datasets/video_utils.py | 16 +- lerobot/common/envs/configs.py | 24 +- lerobot/common/envs/factory.py | 31 +-- lerobot/common/envs/utils.py | 8 +- lerobot/common/optim/factory.py | 10 +- lerobot/common/optim/optimizers.py | 4 +- lerobot/common/optim/schedulers.py | 18 +- .../common/policies/act/configuration_act.py | 4 +- lerobot/common/policies/act/modeling_act.py | 175 ++++----------- .../diffusion/configuration_diffusion.py | 9 +- .../policies/diffusion/modeling_diffusion.py | 128 +++-------- lerobot/common/policies/factory.py | 12 +- .../hilserl/classifier/modeling_classifier.py | 25 +-- lerobot/common/policies/normalize.py | 36 +-- .../conversion_scripts/compare_with_jax.py | 8 +- .../conversion_scripts/conversion_utils.py | 4 +- .../convert_pi0_to_hf_lerobot.py | 8 +- lerobot/common/policies/pi0/modeling_pi0.py | 88 ++------ .../policies/pi0/paligemma_with_expert.py | 28 +-- lerobot/common/policies/pretrained.py | 24 +- .../common/policies/sac/configuration_sac.py | 4 +- lerobot/common/policies/sac/modeling_sac.py | 109 +++------- .../policies/tdmpc/configuration_tdmpc.py | 12 +- .../common/policies/tdmpc/modeling_tdmpc.py | 130 +++-------- .../policies/vqbet/configuration_vqbet.py | 5 +- .../common/policies/vqbet/modeling_vqbet.py | 173 ++++----------- lerobot/common/policies/vqbet/vqbet_utils.py | 181 ++++------------ .../common/robot_devices/cameras/configs.py | 16 +- .../robot_devices/cameras/intelrealsense.py | 36 +-- .../common/robot_devices/cameras/opencv.py | 42 +--- .../common/robot_devices/control_configs.py | 4 +- lerobot/common/robot_devices/control_utils.py | 27 +-- .../common/robot_devices/motors/dynamixel.py | 80 ++----- .../common/robot_devices/motors/feetech.py | 80 ++----- .../common/robot_devices/robots/configs.py | 8 +- .../robots/dynamixel_calibration.py | 33 +-- .../robots/feetech_calibration.py | 95 +++----- .../robot_devices/robots/lekiwi_remote.py | 16 +- .../robot_devices/robots/manipulator.py | 76 ++----- .../robots/mobile_manipulator.py | 25 +-- .../common/robot_devices/robots/stretch.py | 24 +- lerobot/common/utils/hub.py | 8 +- lerobot/common/utils/import_utils.py | 4 +- lerobot/common/utils/io_utils.py | 27 +-- lerobot/common/utils/logging_utils.py | 12 +- lerobot/common/utils/random_utils.py | 4 +- lerobot/common/utils/utils.py | 41 +--- lerobot/common/utils/wandb_utils.py | 12 +- lerobot/configs/default.py | 4 +- lerobot/configs/eval.py | 4 +- lerobot/configs/parser.py | 20 +- lerobot/configs/policies.py | 10 +- lerobot/configs/train.py | 22 +- lerobot/scripts/configure_motor.py | 40 +--- lerobot/scripts/control_robot.py | 14 +- lerobot/scripts/control_sim_robot.py | 33 +-- lerobot/scripts/display_sys_info.py | 10 +- lerobot/scripts/eval.py | 93 ++------ lerobot/scripts/eval_on_robot.py | 26 +-- lerobot/scripts/find_motors_bus_port.py | 8 +- lerobot/scripts/server/actor_server.py | 70 ++---- lerobot/scripts/server/buffer.py | 153 ++++--------- lerobot/scripts/server/crop_dataset_roi.py | 8 +- .../server/end_effector_control_utils.py | 87 +++----- lerobot/scripts/server/find_joint_limits.py | 8 +- lerobot/scripts/server/gym_manipulator.py | 205 +++++------------- lerobot/scripts/server/kinematics.py | 40 +--- lerobot/scripts/server/learner_server.py | 128 ++++------- lerobot/scripts/server/learner_service.py | 12 +- .../scripts/server/maniskill_manipulator.py | 35 ++- lerobot/scripts/server/network_utils.py | 23 +- lerobot/scripts/server/utils.py | 3 +- lerobot/scripts/train.py | 16 +- lerobot/scripts/train_hilserl_classifier.py | 70 ++---- lerobot/scripts/train_sac.py | 94 +++----- lerobot/scripts/visualize_dataset.py | 8 +- lerobot/scripts/visualize_dataset_html.py | 69 ++---- lerobot/scripts/visualize_image_transforms.py | 24 +- .../datasets/save_dataset_to_safetensors.py | 8 +- .../policies/save_policy_to_safetensors.py | 24 +- tests/cameras/mock_pyrealsense2.py | 4 +- tests/configs/test_plugin_loading.py | 4 +- tests/datasets/test_compute_stats.py | 53 +---- tests/datasets/test_datasets.py | 32 +-- tests/datasets/test_delta_timestamps.py | 32 +-- tests/datasets/test_image_transforms.py | 32 +-- tests/datasets/test_image_writer.py | 12 +- tests/datasets/test_online_buffer.py | 76 ++----- tests/fixtures/dataset_factories.py | 79 ++----- tests/fixtures/files.py | 4 +- tests/fixtures/hub.py | 16 +- tests/fixtures/optimizers.py | 4 +- tests/motors/mock_dynamixel_sdk.py | 4 +- tests/motors/mock_scservo_sdk.py | 4 +- tests/motors/test_motors.py | 4 +- tests/optim/test_schedulers.py | 4 +- .../check_hiserl_reward_classifier.py | 40 +--- tests/policies/test_policies.py | 59 ++--- tests/robots/test_control_robot.py | 24 +- tests/robots/test_robots.py | 12 +- tests/test_train_hilserl_classifier.py | 20 +- tests/utils.py | 62 ++---- tests/utils/test_logging_utils.py | 20 +- tests/utils/test_random_utils.py | 8 +- tests/utils/test_train_utils.py | 4 +- 123 files changed, 1161 insertions(+), 3425 deletions(-) diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index 64240282..5612fed6 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -85,9 +85,7 @@ def get_directory_size(directory: Path) -> int: return total_size -def load_original_frames( - imgs_dir: Path, timestamps: list[float], fps: int -) -> torch.Tensor: +def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor: frames = [] for ts in timestamps: idx = int(ts * fps) @@ -129,9 +127,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: hf_dataset = dataset.hf_dataset.with_format(None) # We only save images from the first camera - img_keys = [ - key for key in hf_dataset.features if key.startswith("observation.image") - ] + img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate( @@ -148,9 +144,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: break -def sample_timestamps( - timestamps_mode: str, ep_num_images: int, fps: int -) -> list[float]: +def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]: # Start at 5 to allow for 2_frames_4_space and 6_frames idx = random.randint(5, ep_num_images - 1) match timestamps_mode: @@ -175,9 +169,7 @@ def decode_video_frames( backend: str, ) -> torch.Tensor: if backend in ["pyav", "video_reader"]: - return decode_video_frames_torchvision( - video_path, timestamps, tolerance_s, backend - ) + return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) else: raise NotImplementedError(backend) @@ -204,9 +196,7 @@ def benchmark_decoding( } with time_benchmark: - frames = decode_video_frames( - video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend - ) + frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend) result["load_time_video_ms"] = time_benchmark.result_ms / num_frames with time_benchmark: @@ -215,18 +205,12 @@ def benchmark_decoding( frames_np, original_frames_np = frames.numpy(), original_frames.numpy() for i in range(num_frames): - result["mse_values"].append( - mean_squared_error(original_frames_np[i], frames_np[i]) - ) + result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i])) result["psnr_values"].append( - peak_signal_noise_ratio( - original_frames_np[i], frames_np[i], data_range=1.0 - ) + peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0) ) result["ssim_values"].append( - structural_similarity( - original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0 - ) + structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0) ) if save_frames and sample == 0: @@ -246,9 +230,7 @@ def benchmark_decoding( # As these samples are independent, we run them in parallel threads to speed up the benchmark. with ThreadPoolExecutor(max_workers=num_workers) as executor: futures = [executor.submit(process_sample, i) for i in range(num_samples)] - for future in tqdm( - as_completed(futures), total=num_samples, desc="samples", leave=False - ): + for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False): result = future.result() load_times_video_ms.append(result["load_time_video_ms"]) load_times_images_ms.append(result["load_time_images_ms"]) @@ -312,9 +294,7 @@ def benchmark_encoding_decoding( desc="decodings (timestamps_modes)", leave=False, ): - for backend in tqdm( - decoding_cfg["backends"], desc="decodings (backends)", leave=False - ): + for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False): benchmark_row = benchmark_decoding( imgs_dir, video_path, @@ -392,23 +372,14 @@ def main( imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_") # We only use the first episode save_first_episode(imgs_dir, dataset) - for key, values in tqdm( - encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False - ): + for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False): for value in tqdm(values, desc=f"encodings ({key})", leave=False): encoding_cfg = BASE_ENCODING.copy() encoding_cfg["vcodec"] = video_codec encoding_cfg["pix_fmt"] = pixel_format encoding_cfg[key] = value - args_path = Path( - "_".join(str(value) for value in encoding_cfg.values()) - ) - video_path = ( - output_dir - / "videos" - / args_path - / f"{repo_id.replace('/', '_')}.mp4" - ) + args_path = Path("_".join(str(value) for value in encoding_cfg.values())) + video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4" benchmark_table += benchmark_encoding_decoding( dataset, video_path, @@ -434,9 +405,7 @@ def main( # Concatenate all results df_list = [pd.read_csv(csv_path) for csv_path in file_paths] concatenated_df = pd.concat(df_list, ignore_index=True) - concatenated_path = ( - output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv" - ) + concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv" concatenated_df.to_csv(concatenated_path, header=True, index=False) diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py index 8877a7f8..724b9451 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/1_load_lerobot_dataset.py @@ -43,10 +43,7 @@ pprint(lerobot.available_datasets) # You can also browse through the datasets created/ported by the community on the hub using the hub api: hub_api = HfApi() -repo_ids = [ - info.id - for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"]) -] +repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])] pprint(repo_ids) # Or simply explore them in your web browser directly at: @@ -61,9 +58,7 @@ ds_meta = LeRobotDatasetMetadata(repo_id) # structure of the dataset without downloading the actual data yet (only metadata files — which are # lightweight). print(f"Total number of episodes: {ds_meta.total_episodes}") -print( - f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}" -) +print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}") print(f"Frames per second used during data collection: {ds_meta.fps}") print(f"Robot type: {ds_meta.robot_type}") print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n") diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index ab784193..5193d5ba 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -51,18 +51,12 @@ def main(): # - dataset stats: for normalization and denormalization of input/outputs dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht") features = dataset_to_policy_features(dataset_metadata.features) - output_features = { - key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION - } - input_features = { - key: ft for key, ft in features.items() if key not in output_features - } + output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + input_features = {key: ft for key, ft in features.items() if key not in output_features} # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example, # we'll just use the defaults and so no arguments other than input/output features need to be passed. - cfg = DiffusionConfig( - input_features=input_features, output_features=output_features - ) + cfg = DiffusionConfig(input_features=input_features, output_features=output_features) # We can now instantiate our policy with this config and the dataset stats. policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats) @@ -72,12 +66,8 @@ def main(): # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames # which can differ for inputs, outputs and rewards (if there are some). delta_timestamps = { - "observation.image": [ - i / dataset_metadata.fps for i in cfg.observation_delta_indices - ], - "observation.state": [ - i / dataset_metadata.fps for i in cfg.observation_delta_indices - ], + "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], + "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices], } @@ -129,10 +119,7 @@ def main(): done = False while not done: for batch in dataloader: - batch = { - k: (v.to(device) if isinstance(v, torch.Tensor) else v) - for k, v in batch.items() - } + batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} loss, _ = policy.forward(batch) loss.backward() optimizer.step() diff --git a/examples/advanced/1_add_image_transforms.py b/examples/advanced/1_add_image_transforms.py index 78dc6152..f1460926 100644 --- a/examples/advanced/1_add_image_transforms.py +++ b/examples/advanced/1_add_image_transforms.py @@ -48,14 +48,10 @@ transforms = v2.Compose( ) # Create another LeRobotDataset with the defined transformations -transformed_dataset = LeRobotDataset( - dataset_repo_id, episodes=[0], image_transforms=transforms -) +transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms) # Get a frame from the transformed dataset -transformed_frame = transformed_dataset[first_idx][ - transformed_dataset.meta.camera_keys[0] -] +transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]] # Create a directory to store output images output_dir = Path("outputs/image_transforms") diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index e7eb7fb4..230724e6 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -90,9 +90,7 @@ def main(): train_dataset = LeRobotDataset( "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps ) - val_dataset = LeRobotDataset( - "lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps - ) + val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps) print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}") print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}") diff --git a/lerobot/__init__.py b/lerobot/__init__.py index dec96226..d61e4853 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -164,11 +164,7 @@ available_real_world_datasets = [ ] available_datasets = sorted( - set( - itertools.chain( - *available_datasets_per_env.values(), available_real_world_datasets - ) - ) + set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)) ) # lists all available policies from `lerobot/common/policies` @@ -209,13 +205,9 @@ available_policies_per_env = { "aloha_real": ["act_aloha_real"], } -env_task_pairs = [ - (env, task) for env, tasks in available_tasks_per_env.items() for task in tasks -] +env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] env_dataset_pairs = [ - (env, dataset) - for env, datasets in available_datasets_per_env.items() - for dataset in datasets + (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets ] env_dataset_policy_triplets = [ (env, dataset, policy) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 0202b0e2..ce3113c3 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -46,18 +46,14 @@ def sample_indices(data_len: int) -> list[int]: return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist() -def auto_downsample_height_width( - img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300 -): +def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300): _, height, width = img.shape if max(width, height) < max_size_threshold: # no downsampling needed return img - downsample_factor = ( - int(width / target_size) if width > height else int(height / target_size) - ) + downsample_factor = int(width / target_size) if width > height else int(height / target_size) return img[:, ::downsample_factor, ::downsample_factor] @@ -79,9 +75,7 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def get_feature_stats( - array: np.ndarray, axis: tuple, keepdims: bool -) -> dict[str, np.ndarray]: +def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: return { "min": np.min(array, axis=axis, keepdims=keepdims), "max": np.max(array, axis=axis, keepdims=keepdims), @@ -91,9 +85,7 @@ def get_feature_stats( } -def compute_episode_stats( - episode_data: dict[str, list[str] | np.ndarray], features: dict -) -> dict: +def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict: ep_stats = {} for key, data in episode_data.items(): if features[key]["dtype"] == "string": @@ -107,15 +99,12 @@ def compute_episode_stats( axes_to_reduce = 0 # compute stats over the first axis keepdims = data.ndim == 1 # keep as np.array - ep_stats[key] = get_feature_stats( - ep_ft_array, axis=axes_to_reduce, keepdims=keepdims - ) + ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) # finally, we normalize and remove batch dim for images if features[key]["dtype"] in ["image", "video"]: ep_stats[key] = { - k: v if k == "count" else np.squeeze(v / 255.0, axis=0) - for k, v in ep_stats[key].items() + k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() } return ep_stats @@ -130,17 +119,11 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]): f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead." ) if v.ndim == 0: - raise ValueError( - "Number of dimensions must be at least 1, and is 0 instead." - ) + raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") if k == "count" and v.shape != (1,): - raise ValueError( - f"Shape of 'count' must be (1), but is {v.shape} instead." - ) + raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") if "image" in fkey and k != "count" and v.shape != (3, 1, 1): - raise ValueError( - f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead." - ) + raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") def aggregate_feature_stats( diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index f7a12838..38c01b42 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -58,9 +58,7 @@ def resolve_delta_timestamps( if key == "action" and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] if key.startswith("observation.") and cfg.observation_delta_indices is not None: - delta_timestamps[key] = [ - i / ds_meta.fps for i in cfg.observation_delta_indices - ] + delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] if len(delta_timestamps) == 0: delta_timestamps = None @@ -81,9 +79,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas LeRobotDataset | MultiLeRobotDataset """ image_transforms = ( - ImageTransforms(cfg.dataset.image_transforms) - if cfg.dataset.image_transforms.enable - else None + ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None ) if isinstance(cfg.dataset.repo_id, str): @@ -117,8 +113,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas if cfg.dataset.use_imagenet_stats: for key in dataset.meta.camera_keys: for stats_type, stats in IMAGENET_STATS.items(): - dataset.meta.stats[key][stats_type] = torch.tensor( - stats, dtype=torch.float32 - ) + dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) return dataset diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 81e5bae9..6fc0ee2f 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -38,14 +38,10 @@ def safe_stop_image_writer(func): return wrapper -def image_array_to_pil_image( - image_array: np.ndarray, range_check: bool = True -) -> PIL.Image.Image: +def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: # TODO(aliberts): handle 1 channel and 4 for depth images if image_array.ndim != 3: - raise ValueError( - f"The array has {image_array.ndim} dimensions, but 3 is expected for an image." - ) + raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.") if image_array.shape[0] == 3: # Transpose from pytorch convention (C, H, W) to (H, W, C) @@ -131,9 +127,7 @@ class AsyncImageWriter: self._stopped = False if num_threads <= 0 and num_processes <= 0: - raise ValueError( - "Number of threads and processes must be greater than zero." - ) + raise ValueError("Number of threads and processes must be greater than zero.") if self.num_processes == 0: # Use threading @@ -147,16 +141,12 @@ class AsyncImageWriter: # Use multiprocessing self.queue = multiprocessing.JoinableQueue() for _ in range(self.num_processes): - p = multiprocessing.Process( - target=worker_process, args=(self.queue, self.num_threads) - ) + p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads)) p.daemon = True p.start() self.processes.append(p) - def save_image( - self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path - ): + def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path): if isinstance(image, torch.Tensor): # Convert tensor to numpy array to minimize main process time image = image.cpu().numpy() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 2b6c787b..42d7a3fc 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -108,9 +108,7 @@ class LeRobotDatasetMetadata: self.episodes = load_episodes(self.root) if self._version < packaging.version.parse("v2.1"): self.stats = load_stats(self.root) - self.episodes_stats = backward_compatible_episodes_stats( - self.stats, self.episodes - ) + self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) else: self.episodes_stats = load_episodes_stats(self.root) self.stats = aggregate_stats(list(self.episodes_stats.values())) @@ -141,9 +139,7 @@ class LeRobotDatasetMetadata: def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.video_path.format( - episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index - ) + fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) return Path(fpath) def get_episode_chunk(self, ep_index: int) -> int: @@ -187,11 +183,7 @@ class LeRobotDatasetMetadata: @property def camera_keys(self) -> list[str]: """Keys to access visual modalities (regardless of their storage method).""" - return [ - key - for key, ft in self.features.items() - if ft["dtype"] in ["video", "image"] - ] + return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] @property def names(self) -> dict[str, list | dict]: @@ -240,9 +232,7 @@ class LeRobotDatasetMetadata: Given a task in natural language, add it to the dictionary of tasks. """ if task in self.task_to_task_index: - raise ValueError( - f"The task '{task}' already exists and can't be added twice." - ) + raise ValueError(f"The task '{task}' already exists and can't be added twice.") task_index = self.info["total_tasks"] self.task_to_task_index[task] = task_index @@ -285,11 +275,7 @@ class LeRobotDatasetMetadata: write_episode(episode_dict, self.root) self.episodes_stats[episode_index] = episode_stats - self.stats = ( - aggregate_stats([self.stats, episode_stats]) - if self.stats - else episode_stats - ) + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats write_episode_stats(episode_index, episode_stats, self.root) def update_video_info(self) -> None: @@ -299,9 +285,7 @@ class LeRobotDatasetMetadata: """ for key in self.video_keys: if not self.features[key].get("info", None): - video_path = self.root / self.get_video_file_path( - ep_index=0, vid_key=key - ) + video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) self.info["features"][key]["info"] = get_video_info(video_path) def __repr__(self): @@ -353,17 +337,13 @@ class LeRobotDatasetMetadata: # as this would break the dict flattening in the stats computation, which uses '/' as separator for key in features: if "/" in key: - raise ValueError( - f"Feature names should not contain '/'. Found '/' in feature '{key}'." - ) + raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.") features = {**features, **DEFAULT_FEATURES} obj.tasks, obj.task_to_task_index = {}, {} obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} - obj.info = create_empty_dataset_info( - CODEBASE_VERSION, fps, robot_type, features, use_videos - ) + obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) @@ -494,9 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = ( - video_backend if video_backend else get_safe_default_codec() - ) + self.video_backend = video_backend if video_backend else get_safe_default_codec() self.delta_indices = None # Unused attributes @@ -509,39 +487,28 @@ class LeRobotDataset(torch.utils.data.Dataset): self.meta = LeRobotDatasetMetadata( self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) - if self.episodes is not None and self.meta._version >= packaging.version.parse( - "v2.1" - ): - episodes_stats = [ - self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes - ] + if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"): + episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] self.stats = aggregate_stats(episodes_stats) # Load actual data try: if force_cache_sync: raise FileNotFoundError - assert all( - (self.root / fpath).is_file() - for fpath in self.get_episodes_file_paths() - ) + assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths()) self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) self.download_episodes(download_videos) self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index( - self.meta.episodes, self.episodes - ) + self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) # Check timestamps timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy() episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy() ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} - check_timestamps_sync( - timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s - ) + check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) # Setup delta_indices if self.delta_timestamps is not None: @@ -593,9 +560,7 @@ class LeRobotDataset(torch.utils.data.Dataset): else: hub_api.upload_folder(**upload_kwargs) - if not hub_api.file_exists( - self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch - ): + if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): card = create_lerobot_dataset_card( tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs ) @@ -603,12 +568,8 @@ class LeRobotDataset(torch.utils.data.Dataset): if tag_version: with contextlib.suppress(RevisionNotFoundError): - hub_api.delete_tag( - self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset" - ) - hub_api.create_tag( - self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset" - ) + hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") def pull_from_repo( self, @@ -640,11 +601,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) def get_episodes_file_paths(self) -> list[Path]: - episodes = ( - self.episodes - if self.episodes is not None - else list(range(self.meta.total_episodes)) - ) + episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] if len(self.meta.video_keys) > 0: video_files = [ @@ -662,10 +619,7 @@ class LeRobotDataset(torch.utils.data.Dataset): path = str(self.root / "data") hf_dataset = load_dataset("parquet", data_dir=path, split="train") else: - files = [ - str(self.root / self.meta.get_data_file_path(ep_idx)) - for ep_idx in self.episodes - ] + files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] hf_dataset = load_dataset("parquet", data_files=files, split="train") # TODO(aliberts): hf_dataset.set_format("torch") @@ -675,9 +629,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def create_hf_dataset(self) -> datasets.Dataset: features = get_hf_features_from_features(self.features) ft_dict = {col: [] for col in features} - hf_dataset = datasets.Dataset.from_dict( - ft_dict, features=features, split="train" - ) + hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") # TODO(aliberts): hf_dataset.set_format("torch") hf_dataset.set_transform(hf_transform_to_torch) @@ -691,20 +643,12 @@ class LeRobotDataset(torch.utils.data.Dataset): @property def num_frames(self) -> int: """Number of frames in selected episodes.""" - return ( - len(self.hf_dataset) - if self.hf_dataset is not None - else self.meta.total_frames - ) + return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames @property def num_episodes(self) -> int: """Number of episodes selected.""" - return ( - len(self.episodes) - if self.episodes is not None - else self.meta.total_episodes - ) + return len(self.episodes) if self.episodes is not None else self.meta.total_episodes @property def features(self) -> dict[str, dict]: @@ -718,24 +662,16 @@ class LeRobotDataset(torch.utils.data.Dataset): else: return get_hf_features_from_features(self.features) - def _get_query_indices( - self, idx: int, ep_idx: int - ) -> tuple[dict[str, list[int | bool]]]: + def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: ep_start = self.episode_data_index["from"][ep_idx] ep_end = self.episode_data_index["to"][ep_idx] query_indices = { - key: [ - max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) - for delta in delta_idx - ] + key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] for key, delta_idx in self.delta_indices.items() } padding = { # Pad values outside of current episode range f"{key}_is_pad": torch.BoolTensor( - [ - (idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) - for delta in delta_idx - ] + [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx] ) for key, delta_idx in self.delta_indices.items() } @@ -763,9 +699,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if key not in self.meta.video_keys } - def _query_videos( - self, query_timestamps: dict[str, list[float]], ep_idx: int - ) -> dict[str, torch.Tensor]: + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. This probably happens because a memory reference to the video loader is created in @@ -774,9 +708,7 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {} for vid_key, query_ts in query_timestamps.items(): video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames( - video_path, query_ts, self.tolerance_s, self.video_backend - ) + frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend) item[vid_key] = frames.squeeze(0) return item @@ -830,9 +762,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ) def create_episode_buffer(self, episode_index: int | None = None) -> dict: - current_ep_idx = ( - self.meta.total_episodes if episode_index is None else episode_index - ) + current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index ep_buffer = {} # size and task are special cases that are not in self.features ep_buffer["size"] = 0 @@ -841,17 +771,13 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_buffer[key] = current_ep_idx if key == "episode_index" else [] return ep_buffer - def _get_image_file_path( - self, episode_index: int, image_key: str, frame_index: int - ) -> Path: + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: fpath = DEFAULT_IMAGE_PATH.format( image_key=image_key, episode_index=episode_index, frame_index=frame_index ) return self.root / fpath - def _save_image( - self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path - ) -> None: + def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: if self.image_writer is None: if isinstance(image, torch.Tensor): image = image.cpu().numpy() @@ -877,9 +803,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] - timestamp = ( - frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps - ) + timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) @@ -930,9 +854,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_tasks = list(set(tasks)) episode_index = episode_buffer["episode_index"] - episode_buffer["index"] = np.arange( - self.meta.total_frames, self.meta.total_frames + episode_length - ) + episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) episode_buffer["episode_index"] = np.full((episode_length,), episode_index) # Add new tasks to the tasks dictionary @@ -942,9 +864,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.meta.add_task(task) # Given tasks in natural language, find their corresponding task indices - episode_buffer["task_index"] = np.array( - [self.meta.get_task_index(task) for task in tasks] - ) + episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) for key, ft in self.features.items(): # index, episode_index, task_index are already processed above, and image and video @@ -994,9 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: episode_dict = {key: episode_buffer[key] for key in self.hf_features} - ep_dataset = datasets.Dataset.from_dict( - episode_dict, features=self.hf_features, split="train" - ) + ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") ep_dataset = embed_images(ep_dataset) self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset]) self.hf_dataset.set_transform(hf_transform_to_torch) @@ -1115,9 +1033,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.delta_indices = None obj.episode_data_index = None - obj.video_backend = ( - video_backend if video_backend is not None else get_safe_default_codec() - ) + obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() return obj @@ -1142,9 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): super().__init__() self.repo_ids = repo_ids self.root = Path(root) if root else HF_LEROBOT_HOME - self.tolerances_s = ( - tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) - ) + self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which # are handled by this class. self._datasets = [ @@ -1223,13 +1137,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): def features(self) -> datasets.Features: features = {} for dataset in self._datasets: - features.update( - { - k: v - for k, v in dataset.hf_features.items() - if k not in self.disabled_features - } - ) + features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) return features @property @@ -1290,9 +1198,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): continue break else: - raise AssertionError( - "We expect the loop to break out as long as the index is within bounds." - ) + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") item = self._datasets[dataset_idx][idx - start_idx] item["dataset_index"] = torch.tensor(dataset_idx) for data_key in self.disabled_features: diff --git a/lerobot/common/datasets/online_buffer.py b/lerobot/common/datasets/online_buffer.py index e31206fa..88ae39d6 100644 --- a/lerobot/common/datasets/online_buffer.py +++ b/lerobot/common/datasets/online_buffer.py @@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset): else: self._delta_timestamps = None - def _make_data_spec( - self, data_spec: dict[str, Any], buffer_capacity: int - ) -> dict[str, dict[str, Any]]: + def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]: """Makes the data spec for np.memmap.""" if any(k.startswith("_") for k in data_spec): raise ValueError( @@ -208,9 +206,7 @@ class OnlineBuffer(torch.utils.data.Dataset): # Shift the incoming indices if necessary. if self.num_frames > 0: - last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][ - next_index - 1 - ] + last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 @@ -245,11 +241,7 @@ class OnlineBuffer(torch.utils.data.Dataset): @property def num_episodes(self) -> int: return len( - np.unique( - self._data[OnlineBuffer.EPISODE_INDEX_KEY][ - self._data[OnlineBuffer.OCCUPANCY_MASK_KEY] - ] - ) + np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) ) @property @@ -287,9 +279,7 @@ class OnlineBuffer(torch.utils.data.Dataset): self._data[OnlineBuffer.OCCUPANCY_MASK_KEY], ) )[0] - episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][ - episode_data_indices - ] + episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices] for data_key in self.delta_timestamps: # Note: The logic in this loop is copied from `load_previous_and_future_frames`. @@ -306,8 +296,7 @@ class OnlineBuffer(torch.utils.data.Dataset): # Check violated query timestamps are all outside the episode range. assert ( - (query_ts[is_pad] < episode_timestamps[0]) - | (episode_timestamps[-1] < query_ts[is_pad]) + (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad]) ).all(), ( f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}" ") inside the episode range." @@ -322,9 +311,7 @@ class OnlineBuffer(torch.utils.data.Dataset): def get_data_by_key(self, key: str) -> torch.Tensor: """Returns all data for a given data key as a Tensor.""" - return torch.from_numpy( - self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]] - ) + return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]) def compute_sampler_weights( @@ -355,19 +342,13 @@ def compute_sampler_weights( - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not included here to avoid adding complexity. """ - if len(offline_dataset) == 0 and ( - online_dataset is None or len(online_dataset) == 0 - ): - raise ValueError( - "At least one of `offline_dataset` or `online_dataset` should be contain data." - ) + if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0): + raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.") if (online_dataset is None) ^ (online_sampling_ratio is None): raise ValueError( "`online_dataset` and `online_sampling_ratio` must be provided together or not at all." ) - offline_sampling_ratio = ( - 0 if online_sampling_ratio is None else 1 - online_sampling_ratio - ) + offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio weights = [] diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py index 13997c81..7d06f3cf 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/utils.py +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -45,9 +45,7 @@ def concatenate_episodes(ep_dicts): return data_dict -def save_images_concurrently( - imgs_array: numpy.array, out_dir: Path, max_workers: int = 4 -): +def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4): out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) @@ -57,10 +55,7 @@ def save_images_concurrently( num_images = len(imgs_array) with ThreadPoolExecutor(max_workers=max_workers) as executor: - [ - executor.submit(save_image, imgs_array[i], i, out_dir) - for i in range(num_images) - ] + [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)] def get_default_encoding() -> dict: @@ -69,8 +64,7 @@ def get_default_encoding() -> dict: return { k: v.default for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty - and k in ["vcodec", "pix_fmt", "g", "crf"] + if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"] } diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 2dbd5904..720c939b 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -58,9 +58,7 @@ class RandomSubsetApply(Transform): elif not isinstance(n_subset, int): raise TypeError("n_subset should be an int or None") elif not (1 <= n_subset <= len(transforms)): - raise ValueError( - f"n_subset should be in the interval [1, {len(transforms)}]" - ) + raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") self.transforms = transforms total = sum(p) @@ -121,36 +119,26 @@ class SharpnessJitter(Transform): def _check_input(self, sharpness): if isinstance(sharpness, (int, float)): if sharpness < 0: - raise ValueError( - "If sharpness is a single number, it must be non negative." - ) + raise ValueError("If sharpness is a single number, it must be non negative.") sharpness = [1.0 - sharpness, 1.0 + sharpness] sharpness[0] = max(sharpness[0], 0.0) elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2: sharpness = [float(v) for v in sharpness] else: - raise TypeError( - f"{sharpness=} should be a single number or a sequence with length 2." - ) + raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") if not 0.0 <= sharpness[0] <= sharpness[1]: - raise ValueError( - f"sharpnesss values should be between (0., inf), but got {sharpness}." - ) + raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") return float(sharpness[0]), float(sharpness[1]) def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - sharpness_factor = ( - torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item() - ) + sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item() return {"sharpness_factor": sharpness_factor} def transform(self, inpt: Any, params: dict[str, Any]) -> Any: sharpness_factor = params["sharpness_factor"] - return self._call_kernel( - F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor - ) + return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) @dataclass diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index a68fce97..59178761 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -52,15 +52,9 @@ STATS_PATH = "meta/stats.json" EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" TASKS_PATH = "meta/tasks.jsonl" -DEFAULT_VIDEO_PATH = ( - "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" -) -DEFAULT_PARQUET_PATH = ( - "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" -) -DEFAULT_IMAGE_PATH = ( - "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" -) +DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" +DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" +DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" DATASET_CARD_TEMPLATE = """ --- @@ -135,9 +129,7 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: elif isinstance(value, (int, float)): serialized_dict[key] = value else: - raise NotImplementedError( - f"The value '{value}' of type '{type(value)}' is not supported." - ) + raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") return unflatten_dict(serialized_dict) @@ -216,10 +208,7 @@ def write_task(task_index: int, task: dict, local_dir: Path): def load_tasks(local_dir: Path) -> tuple[dict, dict]: tasks = load_jsonlines(local_dir / TASKS_PATH) - tasks = { - item["task_index"]: item["task"] - for item in sorted(tasks, key=lambda x: x["task_index"]) - } + tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} task_to_task_index = {task: task_index for task_index, task in tasks.items()} return tasks, task_to_task_index @@ -230,10 +219,7 @@ def write_episode(episode: dict, local_dir: Path): def load_episodes(local_dir: Path) -> dict: episodes = load_jsonlines(local_dir / EPISODES_PATH) - return { - item["episode_index"]: item - for item in sorted(episodes, key=lambda x: x["episode_index"]) - } + return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): @@ -286,9 +272,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): elif first_item is None: pass else: - items_dict[key] = [ - x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key] - ] + items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] return items_dict @@ -341,9 +325,7 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> Otherwise, will throw a `CompatibilityError`. """ target_version = ( - packaging.version.parse(version) - if not isinstance(version, packaging.version.Version) - else version + packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version ) hub_versions = get_repo_versions(repo_id) @@ -364,16 +346,12 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> return f"v{target_version}" compatibles = [ - v - for v in hub_versions - if v.major == target_version.major and v.minor <= target_version.minor + v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor ] if compatibles: return_version = max(compatibles) if return_version < target_version: - logging.warning( - f"Revision {version} for {repo_id} not found, using version v{return_version}" - ) + logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}") return f"v{return_version}" lower_major = [v for v in hub_versions if v.major < target_version.major] @@ -480,9 +458,7 @@ def create_empty_dataset_info( def get_episode_data_index( episode_dicts: dict[dict], episodes: list[int] | None = None ) -> dict[str, torch.Tensor]: - episode_lengths = { - ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items() - } + episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} if episodes is not None: episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} @@ -532,9 +508,7 @@ def check_timestamps_sync( # Mask to ignore differences at the boundaries between episodes mask = np.ones(len(diffs), dtype=bool) - ignored_diffs = ( - episode_data_index["to"][:-1] - 1 - ) # indices at the end of each episode + ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode mask[ignored_diffs] = False filtered_within_tolerance = within_tolerance[mask] @@ -580,14 +554,10 @@ def check_delta_timestamps( """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): - within_tolerance = [ - abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts - ] + within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] if not all(within_tolerance): outside_tolerance[key] = [ - ts - for ts, is_within in zip(delta_ts, within_tolerance, strict=True) - if not is_within + ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within ] if len(outside_tolerance) > 0: @@ -605,9 +575,7 @@ def check_delta_timestamps( return True -def get_delta_indices( - delta_timestamps: dict[str, list[float]], fps: int -) -> dict[str, list[int]]: +def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: delta_indices = {} for key, delta_ts in delta_timestamps.items(): delta_indices[key] = [round(d * fps) for d in delta_ts] @@ -672,9 +640,7 @@ def create_lerobot_dataset_card( ], ) - card_template = ( - importlib.resources.files("lerobot.common.datasets") / "card_template.md" - ).read_text() + card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text() return DatasetCard.from_template( card_data=card_data, @@ -743,18 +709,14 @@ def validate_frame(frame: dict, features: dict): expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} actual_features = set(frame.keys()) - error_message = validate_features_presence( - actual_features, expected_features, optional_features - ) + error_message = validate_features_presence(actual_features, expected_features, optional_features) if "task" in frame: error_message += validate_feature_string("task", frame["task"]) common_features = actual_features & (expected_features | optional_features) for name in common_features - {"task"}: - error_message += validate_feature_dtype_and_shape( - name, features[name], frame[name] - ) + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) if error_message: raise ValueError(error_message) @@ -777,9 +739,7 @@ def validate_features_presence( return error_message -def validate_feature_dtype_and_shape( - name: str, feature: dict, value: np.ndarray | PILImage.Image | str -): +def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): @@ -789,9 +749,7 @@ def validate_feature_dtype_and_shape( elif expected_dtype == "string": return validate_feature_string(name, value) else: - raise NotImplementedError( - f"The feature dtype '{expected_dtype}' is not implemented yet." - ) + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") def validate_feature_numpy_array( @@ -813,17 +771,13 @@ def validate_feature_numpy_array( return error_message -def validate_feature_image_or_video( - name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image -): +def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): actual_shape = value.shape c, h, w = expected_shape - if len(actual_shape) != 3 or ( - actual_shape != (c, h, w) and actual_shape != (h, w, c) - ): + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" elif isinstance(value, PILImage.Image): pass @@ -854,9 +808,7 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: ) if episode_buffer["size"] == 0: - raise ValueError( - "You must add one or several frames with `add_frame` before calling `add_episode`." - ) + raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") buffer_keys = set(episode_buffer.keys()) - {"task", "size"} if not buffer_keys == set(features): diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index 80807e4c..49b93ede 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -218,9 +218,7 @@ def get_features_from_hf_dataset( dtype = ft.feature.dtype shape = (ft.length,) motor_names = ( - robot_config["names"][key] - if robot_config - else [f"motor_{i}" for i in range(ft.length)] + robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)] ) assert len(motor_names) == shape[0] names = {"motors": motor_names} @@ -244,15 +242,11 @@ def get_features_from_hf_dataset( return features -def add_task_index_by_episodes( - dataset: Dataset, tasks_by_episodes: dict -) -> tuple[Dataset, list[str]]: +def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]: df = dataset.to_pandas() tasks = list(set(tasks_by_episodes.values())) tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)} - episodes_to_task_index = { - ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items() - } + episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()} df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int) features = dataset.features @@ -269,19 +263,10 @@ def add_task_index_from_tasks_col( # HACK: This is to clean some of the instructions in our version of Open X datasets prefix_to_clean = "tf.Tensor(b'" suffix_to_clean = "', shape=(), dtype=string)" - df[tasks_col] = ( - df[tasks_col] - .str.removeprefix(prefix_to_clean) - .str.removesuffix(suffix_to_clean) - ) + df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean) # Create task_index col - tasks_by_episode = ( - df.groupby("episode_index")[tasks_col] - .unique() - .apply(lambda x: x.tolist()) - .to_dict() - ) + tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict() tasks = df[tasks_col].unique().tolist() tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)} df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int) @@ -306,9 +291,7 @@ def split_parquet_by_episodes( for ep_chunk in range(total_chunks): ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) - chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format( - episode_chunk=ep_chunk - ) + chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) for ep_idx in range(ep_chunk_start, ep_chunk_end): ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) @@ -340,9 +323,7 @@ def move_videos( videos_moved = False video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")] if len(video_files) == 0: - video_files = [ - str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4") - ] + video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")] videos_moved = True # Videos have already been moved assert len(video_files) == total_episodes * len(video_keys) @@ -373,9 +354,7 @@ def move_videos( target_path = DEFAULT_VIDEO_PATH.format( episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx ) - video_file = V1_VIDEO_FILE.format( - video_key=vid_key, episode_index=ep_idx - ) + video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx) if len(video_dirs) == 1: video_path = video_dirs[0] / video_file else: @@ -392,9 +371,7 @@ def move_videos( subprocess.run(["git", "push"], cwd=work_dir, check=True) -def fix_lfs_video_files_tracking( - work_dir: Path, lfs_untracked_videos: list[str] -) -> None: +def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None: """ HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case, there's no other option than to download the actual files and reupload them with lfs tracking. @@ -418,14 +395,10 @@ def fix_lfs_video_files_tracking( subprocess.run(["git", "push"], cwd=work_dir, check=True) -def fix_gitattributes( - work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path -) -> None: +def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None: shutil.copyfile(clean_gittatributes, current_gittatributes) subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True) - subprocess.run( - ["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True - ) + subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True) subprocess.run(["git", "push"], cwd=work_dir, check=True) @@ -462,9 +435,7 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st return [f for f in video_files if f not in lfs_tracked_files] -def get_videos_info( - repo_id: str, local_dir: Path, video_keys: list[str], branch: str -) -> dict: +def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict: # Assumes first episode video_files = [ DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) @@ -539,31 +510,19 @@ def convert_dataset( if single_task: tasks_by_episodes = dict.fromkeys(episode_indices, single_task) dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = { - ep_idx: [task] for ep_idx, task in tasks_by_episodes.items() - } + tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} elif tasks_path: tasks_by_episodes = load_json(tasks_path) - tasks_by_episodes = { - int(ep_idx): task for ep_idx, task in tasks_by_episodes.items() - } + tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()} dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = { - ep_idx: [task] for ep_idx, task in tasks_by_episodes.items() - } + tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} elif tasks_col: - dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col( - dataset, tasks_col - ) + dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col) else: raise ValueError - assert set(tasks) == { - task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks - } - tasks = [ - {"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks) - ] + assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} + tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] write_jsonlines(tasks, v20_dir / TASKS_PATH) features["task_index"] = { "dtype": "int64", @@ -593,9 +552,7 @@ def convert_dataset( clean_gitattr, branch, ) - videos_info = get_videos_info( - repo_id, v1x_dir, video_keys=video_keys, branch=branch - ) + videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch) for key in video_keys: features[key]["shape"] = ( videos_info[key].pop("video.height"), @@ -603,22 +560,15 @@ def convert_dataset( videos_info[key].pop("video.channels"), ) features[key]["video_info"] = videos_info[key] - assert math.isclose( - videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3 - ) + assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3) if "encoding" in metadata_v1: - assert ( - videos_info[key]["video.pix_fmt"] - == metadata_v1["encoding"]["pix_fmt"] - ) + assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"] else: assert metadata_v1.get("video", 0) == 0 videos_info = None # Split data into 1 parquet file by episode - episode_lengths = split_parquet_by_episodes( - dataset, total_episodes, total_chunks, v20_dir - ) + episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir) if robot_config is not None: robot_type = robot_config.type @@ -656,14 +606,10 @@ def convert_dataset( } write_json(metadata_v2_0, v20_dir / INFO_PATH) convert_stats_to_json(v1x_dir, v20_dir) - card = create_lerobot_dataset_card( - tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs - ) + card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder( - repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch - ) + hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): hub_api.delete_folder( @@ -674,9 +620,7 @@ def convert_dataset( ) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder( - repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch - ) + hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch) hub_api.upload_folder( repo_id=repo_id, diff --git a/lerobot/common/datasets/v21/_remove_language_instruction.py b/lerobot/common/datasets/v21/_remove_language_instruction.py index 66461c59..643ddd3f 100644 --- a/lerobot/common/datasets/v21/_remove_language_instruction.py +++ b/lerobot/common/datasets/v21/_remove_language_instruction.py @@ -35,30 +35,22 @@ def fix_dataset(repo_id: str) -> str: dataset_info = get_dataset_config_info(repo_id, "default") with SuppressWarnings(): - lerobot_metadata = LeRobotDatasetMetadata( - repo_id, revision=V20, force_cache_sync=True - ) + lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True) - meta_features = { - key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video" - } + meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"} parquet_features = set(dataset_info.features) diff_parquet_meta = parquet_features - meta_features diff_meta_parquet = meta_features - parquet_features if diff_parquet_meta: - raise ValueError( - f"In parquet not in info.json: {parquet_features - meta_features}" - ) + raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}") if not diff_meta_parquet: return f"{repo_id}: skipped (no diff)" if diff_meta_parquet: - logging.warning( - f"In info.json not in parquet: {meta_features - parquet_features}" - ) + logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}") assert diff_meta_parquet == {"language_instruction"} lerobot_metadata.features.pop("language_instruction") write_info(lerobot_metadata.info, lerobot_metadata.root) diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index 3cd6e52b..2cc0fc94 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -99,9 +99,7 @@ def convert_dataset( repo_type="dataset", ) - hub_api.create_tag( - repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset" - ) + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") if __name__ == "__main__": diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py index a0df5337..b0bd6de5 100644 --- a/lerobot/common/datasets/v21/convert_stats.py +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -26,9 +26,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import write_episode_stats -def sample_episode_video_frames( - dataset: LeRobotDataset, episode_index: int, ft_key: str -) -> np.ndarray: +def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: ep_len = dataset.meta.episodes[episode_index]["length"] sampled_indices = sample_indices(ep_len) query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices}) @@ -51,14 +49,11 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0 keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1 - ep_stats[key] = get_feature_stats( - ep_ft_data, axis=axes_to_reduce, keepdims=keepdims - ) + ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims) if ft["dtype"] in ["image", "video"]: # remove batch dim ep_stats[key] = { - k: v if k == "count" else np.squeeze(v, axis=0) - for k, v in ep_stats[key].items() + k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() } dataset.meta.episodes_stats[ep_idx] = ep_stats diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 3a2fe7ca..c38d570d 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -65,9 +65,7 @@ def decode_video_frames( if backend == "torchcodec": return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) elif backend in ["pyav", "video_reader"]: - return decode_video_frames_torchvision( - video_path, timestamps, tolerance_s, backend - ) + return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) else: raise ValueError(f"Unsupported video backend: {backend}") @@ -346,9 +344,7 @@ def get_audio_info(video_path: Path | str) -> dict: "json", str(video_path), ] - result = subprocess.run( - ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True - ) + result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) if result.returncode != 0: raise RuntimeError(f"Error running ffprobe: {result.stderr}") @@ -362,9 +358,7 @@ def get_audio_info(video_path: Path | str) -> dict: "has_audio": True, "audio.channels": audio_stream_info.get("channels", None), "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) - if audio_stream_info.get("bit_rate") - else None, + "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, "audio.sample_rate": int(audio_stream_info["sample_rate"]) if audio_stream_info.get("sample_rate") else None, @@ -386,9 +380,7 @@ def get_video_info(video_path: Path | str) -> dict: "json", str(video_path), ] - result = subprocess.run( - ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True - ) + result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) if result.returncode != 0: raise RuntimeError(f"Error running ffprobe: {result.stderr}") diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index 69c31684..cf90048a 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -61,16 +61,10 @@ class AlohaEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": - self.features["top"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(480, 640, 3) - ) + self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) elif self.obs_type == "pixels_agent_pos": - self.features["agent_pos"] = PolicyFeature( - type=FeatureType.STATE, shape=(14,) - ) - self.features["pixels/top"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(480, 640, 3) - ) + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) + self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) @property def gym_kwargs(self) -> dict: @@ -108,13 +102,9 @@ class PushtEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["pixels"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(384, 384, 3) - ) + self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) elif self.obs_type == "environment_state_agent_pos": - self.features["environment_state"] = PolicyFeature( - type=FeatureType.ENV, shape=(16,) - ) + self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) @property def gym_kwargs(self) -> dict: @@ -153,9 +143,7 @@ class XarmEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["agent_pos"] = PolicyFeature( - type=FeatureType.STATE, shape=(4,) - ) + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) @property def gym_kwargs(self) -> dict: diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 9670e456..c3996d84 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -32,9 +32,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: raise ValueError(f"Policy type '{env_type}' is not available.") -def make_env( - cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False -) -> gym.vector.VectorEnv | None: +def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None: """Makes a gym vector environment according to the config. Args: @@ -58,9 +56,7 @@ def make_env( try: importlib.import_module(package_name) except ModuleNotFoundError as e: - print( - f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`" - ) + print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`") raise e gym_handle = f"{package_name}/{cfg.task}" @@ -68,18 +64,13 @@ def make_env( # batched version of the env that returns an observation of shape (b, c) env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv env = env_cls( - [ - lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) - for _ in range(n_envs) - ] + [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] ) return env -def make_maniskill_env( - cfg: DictConfig, n_envs: int | None = None -) -> gym.vector.VectorEnv | None: +def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None: """Make ManiSkill3 gym environment""" from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv @@ -96,9 +87,7 @@ def make_maniskill_env( # state should have the size of 25 # env = ConvertToLeRobotEnv(env, n_envs) # env = PixelWrapper(cfg, env, n_envs) - env._max_episode_steps = env.max_episode_steps = ( - 50 # gym_utils.find_max_episode_steps_value(env) - ) + env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) env.unwrapped.metadata["render_fps"] = 20 return env @@ -125,11 +114,7 @@ class PixelWrapper(gym.Wrapper): def _get_obs(self, obs): frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2) self._frames.append(frame) - return { - "pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to( - self.env.device - ) - } + return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)} def reset(self, seed): obs, info = self.env.reset() # (seed=seed) @@ -164,9 +149,7 @@ class ConvertToLeRobotEnv(gym.Wrapper): images = torch.concat(images, axis=-1) # flatten the rest of the data which should just be state data - observation = common.flatten_state_dict( - observation, use_torch=True, device=self.base_env.device - ) + observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device) ret = dict() ret["state"] = observation ret["pixels"] = images diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 9ca41460..9feb3c39 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -50,9 +50,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # sanity check that images are channel last _, h, w, c = img.shape - assert c < h and c < w, ( - f"expect channel last images, but instead got {img.shape=}" - ) + assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" # sanity check that images are uint8 assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" @@ -85,9 +83,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: for key, ft in env_cfg.features.items(): if ft.type is FeatureType.VISUAL: if len(ft.shape) != 3: - raise ValueError( - f"Number of dimensions of {key} != 3 (shape={ft.shape})" - ) + raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})") shape = get_channel_first_image_shape(ft.shape) feature = PolicyFeature(type=ft.type, shape=shape) diff --git a/lerobot/common/optim/factory.py b/lerobot/common/optim/factory.py index 0854332b..10ff3df7 100644 --- a/lerobot/common/optim/factory.py +++ b/lerobot/common/optim/factory.py @@ -34,13 +34,7 @@ def make_optimizer_and_scheduler( Returns: tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`. """ - params = ( - policy.get_optim_params() - if cfg.use_policy_training_preset - else policy.parameters() - ) + params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters() optimizer = cfg.optimizer.build(params) - lr_scheduler = ( - cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None - ) + lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None return optimizer, lr_scheduler diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py index bb82cf1d..0cf4124c 100644 --- a/lerobot/common/optim/optimizers.py +++ b/lerobot/common/optim/optimizers.py @@ -102,9 +102,7 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS) -def load_optimizer_state( - optimizer: torch.optim.Optimizer, save_dir: Path -) -> torch.optim.Optimizer: +def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: current_state_dict = optimizer.state_dict() flat_state = load_file(save_dir / OPTIMIZER_STATE) state = unflatten_dict(flat_state) diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index f189a647..9934ab6e 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -36,9 +36,7 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC): return self.get_choice_name(self.__class__) @abc.abstractmethod - def build( - self, optimizer: Optimizer, num_training_steps: int - ) -> LRScheduler | None: + def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None: raise NotImplementedError @@ -79,11 +77,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig): ) return max( 0.0, - 0.5 - * ( - 1.0 - + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress) - ), + 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)), ) return LambdaLR(optimizer, lr_lambda, -1) @@ -111,9 +105,7 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): def cosine_decay_schedule(current_step): step = min(current_step, self.num_decay_steps) - cosine_decay = 0.5 * ( - 1 + math.cos(math.pi * step / self.num_decay_steps) - ) + cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps)) alpha = self.decay_lr / self.peak_lr decayed = (1 - alpha) * cosine_decay + alpha return decayed @@ -132,8 +124,6 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler: - state_dict = deserialize_json_into_object( - save_dir / SCHEDULER_STATE, scheduler.state_dict() - ) + state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict()) scheduler.load_state_dict(state_dict) return scheduler diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 3c5d30b7..7a5819b7 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -171,9 +171,7 @@ class ACTConfig(PreTrainedConfig): def validate_features(self) -> None: if not self.image_features and not self.env_state_feature: - raise ValueError( - "You must provide at least one image or the environment state among the inputs." - ) + raise ValueError("You must provide at least one image or the environment state among the inputs.") @property def observation_delta_indices(self) -> None: diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 4ab87107..d34c21ac 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -63,9 +63,7 @@ class ACTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize( - config.input_features, config.normalization_mapping, dataset_stats - ) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -76,9 +74,7 @@ class ACTPolicy(PreTrainedPolicy): self.model = ACT(config) if config.temporal_ensemble_coeff is not None: - self.temporal_ensembler = ACTTemporalEnsembler( - config.temporal_ensemble_coeff, config.chunk_size - ) + self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) self.reset() @@ -122,12 +118,8 @@ class ACTPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [ - batch[key] for key in self.config.image_features - ] + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.config.image_features] # If we are doing temporal ensembling, do online updates where we keep track of the number of actions # we are ensembling over. @@ -154,19 +146,14 @@ class ACTPolicy(PreTrainedPolicy): """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [ - batch[key] for key in self.config.image_features - ] + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.config.image_features] batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") - * ~batch["action_is_pad"].unsqueeze(-1) + F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) ).mean() loss_dict = {"l1_loss": l1_loss.item()} @@ -176,12 +163,7 @@ class ACTPolicy(PreTrainedPolicy): # KL-divergence per batch element, then take the mean over the batch. # (See App. B of https://arxiv.org/abs/1312.6114 for more details). mean_kld = ( - ( - -0.5 - * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp()) - ) - .sum(-1) - .mean() + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) loss_dict["kld_loss"] = mean_kld.item() loss = l1_loss + mean_kld * self.config.kl_weight @@ -235,9 +217,7 @@ class ACTTemporalEnsembler: ``` """ self.chunk_size = chunk_size - self.ensemble_weights = torch.exp( - -temporal_ensemble_coeff * torch.arange(chunk_size) - ) + self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) self.reset() @@ -253,9 +233,7 @@ class ACTTemporalEnsembler: time steps, and pop/return the next batch of actions in the sequence. """ self.ensemble_weights = self.ensemble_weights.to(device=actions.device) - self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to( - device=actions.device - ) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) if self.ensembled_actions is None: # Initializes `self._ensembled_action` to the sequence of actions predicted during the first # time step of the episode. @@ -270,22 +248,12 @@ class ACTTemporalEnsembler: else: # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute # the online update for those entries. - self.ensembled_actions *= self.ensemble_weights_cumsum[ - self.ensembled_actions_count - 1 - ] - self.ensembled_actions += ( - actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] - ) - self.ensembled_actions /= self.ensemble_weights_cumsum[ - self.ensembled_actions_count - ] - self.ensembled_actions_count = torch.clamp( - self.ensembled_actions_count + 1, max=self.chunk_size - ) + self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] + self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] + self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) # The last action, which has no prior online average, needs to get concatenated onto the end. - self.ensembled_actions = torch.cat( - [self.ensembled_actions, actions[:, -1:]], dim=1 - ) + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) self.ensembled_actions_count = torch.cat( [ self.ensembled_actions_count, @@ -356,9 +324,7 @@ class ACT(nn.Module): config.dim_model, ) # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear( - config.dim_model, config.latent_dim * 2 - ) + self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch # dimension. num_input_token_encoder = 1 + config.chunk_size @@ -366,9 +332,7 @@ class ACT(nn.Module): num_input_token_encoder += 1 self.register_buffer( "vae_encoder_pos_enc", - create_sinusoidal_pos_embedding( - num_input_token_encoder, config.dim_model - ).unsqueeze(0), + create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), ) # Backbone for image feature extraction. @@ -385,9 +349,7 @@ class ACT(nn.Module): # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final # feature map). # Note: The forward method of this returns a dict: {"feature_map": output}. - self.backbone = IntermediateLayerGetter( - backbone_model, return_layers={"layer4": "feature_map"} - ) + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) # Transformer (acts as VAE decoder when training with the variational objective). self.encoder = ACTEncoder(config) @@ -416,18 +378,14 @@ class ACT(nn.Module): n_1d_tokens += 1 self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) if self.config.image_features: - self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d( - config.dim_model // 2 - ) + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear( - config.dim_model, self.config.action_feature.shape[0] - ) + self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) self._reset_parameters() @@ -437,9 +395,7 @@ class ACT(nn.Module): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward( - self, batch: dict[str, Tensor] - ) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: """A forward pass through the Action Chunking Transformer (with optional VAE encoder). `batch` should have the following structure: @@ -475,13 +431,9 @@ class ACT(nn.Module): self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) if self.config.robot_state_feature: - robot_state_embed = self.vae_encoder_robot_state_input_proj( - batch["observation.state"] - ) + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj( - batch["action"] - ) # (B, S, D) + action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) if self.config.robot_state_feature: vae_encoder_input = [ @@ -526,26 +478,20 @@ class ACT(nn.Module): # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer - latent_sample = torch.zeros( - [batch_size, self.config.latent_dim], dtype=torch.float32 - ).to(batch["observation.state"].device) + latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( + batch["observation.state"].device + ) # Prepare transformer encoder inputs. encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] - encoder_in_pos_embed = list( - self.encoder_1d_feature_pos_embed.weight.unsqueeze(1) - ) + encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) # Robot state token. if self.config.robot_state_feature: - encoder_in_tokens.append( - self.encoder_robot_state_input_proj(batch["observation.state"]) - ) + encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) # Environment state token. if self.config.env_state_feature: encoder_in_tokens.append( - self.encoder_env_state_input_proj( - batch["observation.environment_state"] - ) + self.encoder_env_state_input_proj(batch["observation.environment_state"]) ) # Camera observation features and positional embeddings. @@ -556,9 +502,7 @@ class ACT(nn.Module): # For a list of images, the H and W may vary but H*W is constant. for img in batch["observation.images"]: cam_features = self.backbone(img)["feature_map"] - cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to( - dtype=cam_features.dtype - ) + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) # Rearrange features to (sequence, batch, dim). @@ -604,14 +548,8 @@ class ACTEncoder(nn.Module): def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): super().__init__() self.is_vae_encoder = is_vae_encoder - num_layers = ( - config.n_vae_encoder_layers - if self.is_vae_encoder - else config.n_encoder_layers - ) - self.layers = nn.ModuleList( - [ACTEncoderLayer(config) for _ in range(num_layers)] - ) + num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers + self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() def forward( @@ -629,9 +567,7 @@ class ACTEncoder(nn.Module): class ACTEncoderLayer(nn.Module): def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention( - config.dim_model, config.n_heads, dropout=config.dropout - ) + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) # Feed forward layers. self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) @@ -646,9 +582,7 @@ class ACTEncoderLayer(nn.Module): self.activation = get_activation_fn(config.feedforward_activation) self.pre_norm = config.pre_norm - def forward( - self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None - ) -> Tensor: + def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: skip = x if self.pre_norm: x = self.norm1(x) @@ -673,9 +607,7 @@ class ACTDecoder(nn.Module): def __init__(self, config: ACTConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() - self.layers = nn.ModuleList( - [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)] - ) + self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) self.norm = nn.LayerNorm(config.dim_model) def forward( @@ -700,12 +632,8 @@ class ACTDecoder(nn.Module): class ACTDecoderLayer(nn.Module): def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention( - config.dim_model, config.n_heads, dropout=config.dropout - ) - self.multihead_attn = nn.MultiheadAttention( - config.dim_model, config.n_heads, dropout=config.dropout - ) + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) # Feed forward layers. self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) @@ -746,9 +674,7 @@ class ACTDecoderLayer(nn.Module): if self.pre_norm: x = self.norm1(x) q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) - x = self.self_attn(q, k, value=x)[ - 0 - ] # select just the output, not the attention weights + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights x = skip + self.dropout1(x) if self.pre_norm: skip = x @@ -785,14 +711,9 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso """ def get_position_angle_vec(position): - return [ - position / np.power(10000, 2 * (hid_j // 2) / dimension) - for hid_j in range(dimension) - ] + return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] - sinusoid_table = np.array( - [get_position_angle_vec(pos_i) for pos_i in range(num_positions)] - ) + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 return torch.from_numpy(sinusoid_table).float() @@ -837,9 +758,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module): x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi inverse_frequency = self._temperature ** ( - 2 - * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) - / self.dimension + 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension ) x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) @@ -847,15 +766,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module): # Note: this stack then flatten operation results in interleaved sine and cosine terms. # pos_embed_x and pos_embed_y are (1, H, W, C // 2). - pos_embed_x = torch.stack( - (x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1 - ).flatten(3) - pos_embed_y = torch.stack( - (y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1 - ).flatten(3) - pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute( - 0, 3, 1, 2 - ) # (1, C, H, W) + pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) return pos_embed diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 32889ff0..e73c65fe 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -205,16 +205,11 @@ class DiffusionConfig(PreTrainedConfig): def validate_features(self) -> None: if len(self.image_features) == 0 and self.env_state_feature is None: - raise ValueError( - "You must provide at least one image or the environment state among the inputs." - ) + raise ValueError("You must provide at least one image or the environment state among the inputs.") if self.crop_shape is not None: for key, image_ft in self.image_features.items(): - if ( - self.crop_shape[0] > image_ft.shape[1] - or self.crop_shape[1] > image_ft.shape[2] - ): + if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: raise ValueError( f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " f"for `crop_shape` and {image_ft.shape} for " diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index cb84292b..1494ee51 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -70,9 +70,7 @@ class DiffusionPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize( - config.input_features, config.normalization_mapping, dataset_stats - ) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -99,9 +97,7 @@ class DiffusionPolicy(PreTrainedPolicy): if self.config.image_features: self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque( - maxlen=self.config.n_obs_steps - ) + self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -127,9 +123,7 @@ class DiffusionPolicy(PreTrainedPolicy): """ batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( [batch[key] for key in self.config.image_features], dim=-4 ) @@ -138,11 +132,7 @@ class DiffusionPolicy(PreTrainedPolicy): if len(self._queues["action"]) == 0: # stack n latest observations from the queue - batch = { - k: torch.stack(list(self._queues[k]), dim=1) - for k in batch - if k in self._queues - } + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? @@ -157,9 +147,7 @@ class DiffusionPolicy(PreTrainedPolicy): """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( [batch[key] for key in self.config.image_features], dim=-4 ) @@ -201,9 +189,7 @@ class DiffusionModel(nn.Module): if self.config.env_state_feature: global_cond_dim += self.config.env_state_feature.shape[0] - self.unet = DiffusionConditionalUnet1d( - config, global_cond_dim=global_cond_dim * config.n_obs_steps - ) + self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) self.noise_scheduler = _make_noise_scheduler( config.noise_scheduler_type, @@ -249,9 +235,7 @@ class DiffusionModel(nn.Module): global_cond=global_cond, ) # Compute previous image: x_t -> x_t-1 - sample = self.noise_scheduler.step( - model_output, t, sample, generator=generator - ).prev_sample + sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample return sample @@ -263,15 +247,11 @@ class DiffusionModel(nn.Module): if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: # Combine batch and sequence dims while rearranging to make the camera index dimension first. - images_per_camera = einops.rearrange( - batch["observation.images"], "b s n ... -> n (b s) ..." - ) + images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") img_features_list = torch.cat( [ encoder(images) - for encoder, images in zip( - self.rgb_encoder, images_per_camera, strict=True - ) + for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True) ] ) # Separate batch and sequence dims back out. The camera index dim gets absorbed into the @@ -285,9 +265,7 @@ class DiffusionModel(nn.Module): else: # Combine batch, sequence, and "which camera" dims before passing to shared encoder. img_features = self.rgb_encoder( - einops.rearrange( - batch["observation.images"], "b s n ... -> (b s n) ..." - ) + einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") ) # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). @@ -381,9 +359,7 @@ class DiffusionModel(nn.Module): elif self.config.prediction_type == "sample": target = batch["action"] else: - raise ValueError( - f"Unsupported prediction type {self.config.prediction_type}" - ) + raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") loss = F.mse_loss(pred, target, reduction="none") @@ -443,9 +419,7 @@ class SpatialSoftmax(nn.Module): # we could use torch.linspace directly but that seems to behave slightly differently than numpy # and causes a small degradation in pc_success of pre-trained models. - pos_x, pos_y = np.meshgrid( - np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) - ) + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() # register as buffer so it's moved to the correct device. @@ -487,9 +461,7 @@ class DiffusionRgbEncoder(nn.Module): # Always use center crop for eval self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) if config.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop( - config.crop_shape - ) + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) else: self.maybe_random_crop = self.center_crop else: @@ -510,9 +482,7 @@ class DiffusionRgbEncoder(nn.Module): self.backbone = _replace_submodules( root_module=self.backbone, predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm( - num_groups=x.num_features // 16, num_channels=x.num_features - ), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), ) # Set up pooling and final layers. @@ -523,15 +493,11 @@ class DiffusionRgbEncoder(nn.Module): # Note: we have a check in the config class to make sure all images have the same shape. images_shape = next(iter(config.image_features.values())).shape - dummy_shape_h_w = ( - config.crop_shape if config.crop_shape is not None else images_shape[1:] - ) + dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] dummy_shape = (1, images_shape[0], *dummy_shape_h_w) feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] - self.pool = SpatialSoftmax( - feature_map_shape, num_kp=config.spatial_softmax_num_keypoints - ) + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() @@ -573,11 +539,7 @@ def _replace_submodules( if predicate(root_module): return func(root_module) - replace_list = [ - k.split(".") - for k, m in root_module.named_modules(remove_duplicate=True) - if predicate(m) - ] + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] for *parents, k in replace_list: parent_module = root_module if len(parents) > 0: @@ -592,9 +554,7 @@ def _replace_submodules( else: setattr(parent_module, k, tgt_module) # verify that all BN are replaced - assert not any( - predicate(m) for _, m in root_module.named_modules(remove_duplicate=True) - ) + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) return root_module @@ -622,9 +582,7 @@ class DiffusionConv1dBlock(nn.Module): super().__init__() self.block = nn.Sequential( - nn.Conv1d( - inp_channels, out_channels, kernel_size, padding=kernel_size // 2 - ), + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), nn.GroupNorm(n_groups, out_channels), nn.Mish(), ) @@ -647,13 +605,9 @@ class DiffusionConditionalUnet1d(nn.Module): # Encoder for the diffusion timestep. self.diffusion_step_encoder = nn.Sequential( DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), - nn.Linear( - config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4 - ), + nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), nn.Mish(), - nn.Linear( - config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim - ), + nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), ) # The FiLM conditioning dimension. @@ -678,16 +632,10 @@ class DiffusionConditionalUnet1d(nn.Module): self.down_modules.append( nn.ModuleList( [ - DiffusionConditionalResidualBlock1d( - dim_in, dim_out, **common_res_block_kwargs - ), - DiffusionConditionalResidualBlock1d( - dim_out, dim_out, **common_res_block_kwargs - ), + DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), # Downsample as long as it is not the last block. - nn.Conv1d(dim_out, dim_out, 3, 2, 1) - if not is_last - else nn.Identity(), + nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), ] ) ) @@ -716,24 +664,16 @@ class DiffusionConditionalUnet1d(nn.Module): nn.ModuleList( [ # dim_in * 2, because it takes the encoder's skip connection as well - DiffusionConditionalResidualBlock1d( - dim_in * 2, dim_out, **common_res_block_kwargs - ), - DiffusionConditionalResidualBlock1d( - dim_out, dim_out, **common_res_block_kwargs - ), + DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), # Upsample as long as it is not the last block. - nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) - if not is_last - else nn.Identity(), + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), ] ) ) self.final_conv = nn.Sequential( - DiffusionConv1dBlock( - config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size - ), + DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size), nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1), ) @@ -801,23 +741,17 @@ class DiffusionConditionalResidualBlock1d(nn.Module): self.use_film_scale_modulation = use_film_scale_modulation self.out_channels = out_channels - self.conv1 = DiffusionConv1dBlock( - in_channels, out_channels, kernel_size, n_groups=n_groups - ) + self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) - self.conv2 = DiffusionConv1dBlock( - out_channels, out_channels, kernel_size, n_groups=n_groups - ) + self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) # A final convolution for dimension matching the residual (if needed). self.residual_conv = ( - nn.Conv1d(in_channels, out_channels, 1) - if in_channels != out_channels - else nn.Identity() + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() ) def forward(self, x: Tensor, cond: Tensor) -> Tensor: diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 705c24b4..455f5c67 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -111,9 +111,7 @@ def make_policy( PreTrainedPolicy: _description_ """ if bool(ds_meta) == bool(env_cfg): - raise ValueError( - "Either one of a dataset metadata or a sim env must be provided." - ) + raise ValueError("Either one of a dataset metadata or a sim env must be provided.") # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error. # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies? @@ -143,12 +141,8 @@ def make_policy( ) features = env_to_policy_features(env_cfg) - cfg.output_features = { - key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION - } - cfg.input_features = { - key: ft for key, ft in features.items() if key not in cfg.output_features - } + cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg if cfg.pretrained_path: diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index eb023f9f..18c30493 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -7,9 +7,7 @@ from torch import Tensor, nn from .configuration_classifier import ClassifierConfig -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -53,9 +51,7 @@ class Classifier( super().__init__() self.config = config # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) - encoder = AutoModel.from_pretrained( - self.config.model_name, trust_remote_code=True - ) + encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) # Extract vision model if we're given a multimodal model if hasattr(encoder, "vision_model"): logging.info("Multimodal model detected - using vision encoder only") @@ -81,9 +77,7 @@ class Classifier( self.feature_dim = self.encoder.fc.in_features self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) elif hasattr(self.encoder.config, "hidden_sizes"): - self.feature_dim = self.encoder.config.hidden_sizes[ - -1 - ] # Last channel dimension + self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") @@ -103,9 +97,7 @@ class Classifier( if hasattr(self.encoder.config, "hidden_size"): input_dim = self.encoder.config.hidden_size else: - raise ValueError( - "Unsupported transformer architecture since hidden_size is not found" - ) + raise ValueError("Unsupported transformer architecture since hidden_size is not found") self.classifier_head = nn.Sequential( nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), @@ -141,10 +133,7 @@ class Classifier( return features else: # Transformer models outputs = self.encoder(processed) - if ( - hasattr(outputs, "pooler_output") - and outputs.pooler_output is not None - ): + if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: return outputs.pooler_output return outputs.last_hidden_state[:, 0, :] @@ -160,9 +149,7 @@ class Classifier( else: probabilities = torch.softmax(logits, dim=-1) - return ClassifierOutput( - logits=logits, probabilities=probabilities, hidden_states=encoder_outputs - ) + return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) def predict_reward(self, x, threshold=0.6): if self.config.num_classes == 2: diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 46428700..012c854d 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -82,43 +82,25 @@ def create_stats_buffers( if stats: if isinstance(stats[key]["mean"], np.ndarray): if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to( - dtype=torch.float32 - ) - buffer["std"].data = torch.from_numpy(stats[key]["std"]).to( - dtype=torch.float32 - ) + buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) + buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = torch.from_numpy(stats[key]["min"]).to( - dtype=torch.float32 - ) - buffer["max"].data = torch.from_numpy(stats[key]["max"]).to( - dtype=torch.float32 - ) + buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) + buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) elif isinstance(stats[key]["mean"], torch.Tensor): # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated # tensors anywhere (for example, when we use the same stats for normalization and # unnormalization). See the logic here # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = ( - stats[key]["mean"].clone().to(dtype=torch.float32) - ) - buffer["std"].data = ( - stats[key]["std"].clone().to(dtype=torch.float32) - ) + buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) + buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = ( - stats[key]["min"].clone().to(dtype=torch.float32) - ) - buffer["max"].data = ( - stats[key]["max"].clone().to(dtype=torch.float32) - ) + buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) + buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) else: type_ = type(stats[key]["mean"]) - raise ValueError( - f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead." - ) + raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") stats_buffers[key] = buffer return stats_buffers diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py index dc729eea..6bd7c91f 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py @@ -44,9 +44,7 @@ def main(): else: dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human" - ckpt_torch_dir = ( - Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" - ) + ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}" save_dir = Path(f"../openpi/data/{model_name}/save") @@ -72,9 +70,7 @@ def main(): # Create LeRobot batch from Jax batch = {} for cam_key, uint_chw_array in example["images"].items(): - batch[f"observation.images.{cam_key}"] = ( - torch.from_numpy(uint_chw_array) / 255.0 - ) + batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 batch["observation.state"] = torch.from_numpy(example["state"]) batch["action"] = torch.from_numpy(outputs["actions"]) batch["task"] = example["prompt"] diff --git a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py index 6b8406a9..8835da31 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py +++ b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py @@ -54,9 +54,7 @@ def get_paligemma_config(precision: str): "projector_hidden_act": "gelu_fast", "vision_use_head": False, } - final_config = PaliGemmaConfig( - text_config=text_config, vision_config=vision_config, **config - ) + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) return final_config diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py index 9fa00fbf..bb51361b 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py +++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py @@ -322,9 +322,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict: return {f"{prefix}{key}": value for key, value in d.items()} -def convert_pi0_checkpoint( - checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str -): +def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str): # Break down orbax ckpts - they are in OCDBT initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) # process projection params @@ -384,9 +382,7 @@ def convert_pi0_checkpoint( # gemma_config=gemma_config, paligemma_config=paligemma_config) pi0_model = PI0Policy(pi0_config) - paligemma_params = update_keys_with_prefix( - paligemma_params, "model.paligemma_with_expert." - ) + paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.") gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.") projection_params = update_keys_with_prefix(projection_params, "model.") diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index dd01a18c..6c60d45d 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -193,9 +193,7 @@ def aloha_gripper_to_angular(value): # This is the inverse of the angular to linear transformation inside the Interbotix code. def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / ( - 2 * horn_radius * linear_position - ) + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) return safe_arcsin(value) # The constants are taken from the Interbotix code. @@ -246,9 +244,7 @@ class PI0Policy(PreTrainedPolicy): super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize( - config.input_features, config.normalization_mapping, dataset_stats - ) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -256,9 +252,7 @@ class PI0Policy(PreTrainedPolicy): config.output_features, config.normalization_mapping, dataset_stats ) - self.language_tokenizer = AutoTokenizer.from_pretrained( - "google/paligemma-3b-pt-224" - ) + self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") self.model = PI0FlowMatching(config) self.reset() @@ -271,9 +265,7 @@ class PI0Policy(PreTrainedPolicy): return self.parameters() @torch.no_grad - def select_action( - self, batch: dict[str, Tensor], noise: Tensor | None = None - ) -> Tensor: + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. This method wraps `select_actions` in order to return one action at a time for execution in the @@ -312,9 +304,7 @@ class PI0Policy(PreTrainedPolicy): self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - def forward( - self, batch: dict[str, Tensor], noise=None, time=None - ) -> tuple[Tensor, dict[str, Tensor]]: + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: """Do a full training forward pass to compute the loss""" if self.config.adapt_to_pi_aloha: batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) @@ -330,9 +320,7 @@ class PI0Policy(PreTrainedPolicy): actions_is_pad = batch.get("action_is_pad") loss_dict = {} - losses = self.model.forward( - images, img_masks, lang_tokens, lang_masks, state, actions, noise, time - ) + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) loss_dict["losses_after_forward"] = losses.clone() if actions_is_pad is not None: @@ -359,9 +347,7 @@ class PI0Policy(PreTrainedPolicy): img_masks = [] present_img_keys = [key for key in self.config.image_features if key in batch] - missing_img_keys = [ - key for key in self.config.image_features if key not in batch - ] + missing_img_keys = [key for key in self.config.image_features if key not in batch] if len(present_img_keys) == 0: raise ValueError( @@ -373,9 +359,7 @@ class PI0Policy(PreTrainedPolicy): img = batch[key] if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad( - img, *self.config.resize_imgs_with_padding, pad_value=0 - ) + img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) # Normalize from range [0,1] to [-1,1] as expacted by siglip img = img * 2.0 - 1.0 @@ -414,9 +398,7 @@ class PI0Policy(PreTrainedPolicy): return_tensors="pt", ) lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to( - device=device, dtype=torch.bool - ) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) return lang_tokens, lang_masks @@ -435,9 +417,7 @@ class PI0Policy(PreTrainedPolicy): actions[:, :, motor_idx] *= -1 # Reverse the gripper transformation that is being applied by the Aloha runtime. for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular( - actions[:, :, motor_idx] - ) + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) return actions def _pi_aloha_encode_actions_inv(self, actions): @@ -446,9 +426,7 @@ class PI0Policy(PreTrainedPolicy): actions[:, :, motor_idx] *= -1 # Reverse the gripper transformation that is being applied by the Aloha runtime. for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv( - actions[:, :, motor_idx] - ) + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) return actions def prepare_state(self, batch): @@ -498,25 +476,15 @@ class PI0FlowMatching(nn.Module): train_expert_only=self.config.train_expert_only, attention_implementation=self.config.attention_implementation, ) - self.paligemma_with_expert = PaliGemmaWithExpertModel( - paligemma_with_export_config - ) + self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config) # Projections are float32 self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) - self.action_in_proj = nn.Linear( - self.config.max_action_dim, self.config.proj_width - ) - self.action_out_proj = nn.Linear( - self.config.proj_width, self.config.max_action_dim - ) + self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width) + self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim) - self.action_time_mlp_in = nn.Linear( - self.config.proj_width * 2, self.config.proj_width - ) - self.action_time_mlp_out = nn.Linear( - self.config.proj_width, self.config.proj_width - ) + self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width) + self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) self.set_requires_grad() @@ -560,9 +528,7 @@ class PI0FlowMatching(nn.Module): # Normalize image embeddings img_emb_dim = img_emb.shape[-1] - img_emb = img_emb * torch.tensor( - img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device - ) + img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) bsize, num_img_embs = img_emb.shape[:2] img_mask = img_mask[:, None].expand(bsize, num_img_embs) @@ -637,9 +603,7 @@ class PI0FlowMatching(nn.Module): embs.append(action_time_emb) bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones( - bsize, action_time_dim, dtype=torch.bool, device=device - ) + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) pad_masks.append(action_time_mask) # Set attention masks so that image, language and state inputs do not attend to action tokens @@ -677,9 +641,7 @@ class PI0FlowMatching(nn.Module): prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks ) - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( - state, x_t, time - ) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -703,9 +665,7 @@ class PI0FlowMatching(nn.Module): losses = F.mse_loss(u_t, v_t, reduction="none") return losses - def sample_actions( - self, images, img_masks, lang_tokens, lang_masks, state, noise=None - ) -> Tensor: + def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" bsize = state.shape[0] device = state.device @@ -763,16 +723,12 @@ class PI0FlowMatching(nn.Module): timestep, ): """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( - state, x_t, timestep - ) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) suffix_len = suffix_pad_masks.shape[1] batch_size = prefix_pad_masks.shape[0] prefix_len = prefix_pad_masks.shape[1] - prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand( - batch_size, suffix_len, prefix_len - ) + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py index fc0ae065..55eea148 100644 --- a/lerobot/common/policies/pi0/paligemma_with_expert.py +++ b/lerobot/common/policies/pi0/paligemma_with_expert.py @@ -39,13 +39,9 @@ def apply_rope(x, positions, max_wavelength=10_000): dtype = x.dtype x = x.to(torch.float32) - freq_exponents = (2.0 / x.shape[-1]) * torch.arange( - d_half, dtype=torch.float32, device=device - ) + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) timescale = max_wavelength**freq_exponents - radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to( - torch.float32 - ) + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) radians = radians[..., None, :] @@ -178,9 +174,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): def __init__(self, config: PaliGemmaWithExpertConfig): super().__init__(config=config) self.config = config - self.paligemma = PaliGemmaForConditionalGeneration( - config=config.paligemma_config - ) + self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config) self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) # Remove unused embed_tokens self.gemma_expert.model.embed_tokens = None @@ -297,9 +291,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach # the max len, then we (for instance) double the cache size. This implementation already exists # in `transformers`. (molbap) - key_states = torch.cat( - [past_key_values[layer_idx]["key_states"], key_states], dim=1 - ) + key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) value_states = torch.cat( [past_key_values[layer_idx]["value_states"], value_states], dim=1, @@ -392,9 +384,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): value_states, ): num_att_heads = self.config.paligemma_config.text_config.num_attention_heads - num_key_value_heads = ( - self.config.paligemma_config.text_config.num_key_value_heads - ) + num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads num_key_value_groups = num_att_heads // num_key_value_heads # query_states: batch_size, sequence_length, num_att_head, head_dim @@ -442,9 +432,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): att_weights *= head_dim**-0.5 big_neg = -2.3819763e38 # See gemma/modules.py - masked_att_weights = torch.where( - attention_mask[:, None, :, :], att_weights, big_neg - ) + masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) probs = nn.functional.softmax(masked_att_weights, dim=-1) probs = probs.to(dtype=value_states.dtype) @@ -456,8 +444,6 @@ class PaliGemmaWithExpertModel(PreTrainedModel): att_output = att_output.permute(0, 2, 1, 3) # we use -1 because sequence length can change - att_output = att_output.reshape( - batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim - ) + att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) return att_output diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index 75d077b6..da4ef157 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -71,9 +71,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): def _save_pretrained(self, save_directory: Path) -> None: self.config._save_pretrained(save_directory) model_to_save = self.module if hasattr(self, "module") else self - save_model_as_safetensor( - model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE) - ) + save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) @classmethod def from_pretrained( @@ -112,9 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) - policy = cls._load_as_safetensor( - instance, model_file, config.device, strict - ) + policy = cls._load_as_safetensor(instance, model_file, config.device, strict) else: try: model_file = hf_hub_download( @@ -128,9 +124,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): token=token, local_files_only=local_files_only, ) - policy = cls._load_as_safetensor( - instance, model_file, config.device, strict - ) + policy = cls._load_as_safetensor(instance, model_file, config.device, strict) except HfHubHTTPError as e: raise FileNotFoundError( f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" @@ -141,12 +135,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): return policy @classmethod - def _load_as_safetensor( - cls, model: T, model_file: str, map_location: str, strict: bool - ) -> T: - if packaging.version.parse(safetensors.__version__) < packaging.version.parse( - "0.4.3" - ): + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): load_model_as_safetensor(model, model_file, strict=strict) if map_location != "cpu": logging.warning( @@ -157,9 +147,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): ) model.to(map_location) else: - safetensors.torch.load_model( - model, model_file, strict=strict, device=map_location - ) + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) return model # def generate_model_card(self, *args, **kwargs) -> ModelCard: diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 3f1a7fbb..8d98ed40 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -48,9 +48,7 @@ class SACConfig: "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]}, } ) - output_normalization_modes: dict[str, str] = field( - default_factory=lambda: {"action": "min_max"} - ) + output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) output_normalization_params: dict[str, dict[str, list[float]]] = field( default_factory=lambda: { "action": {"min": [-1, -1], "max": [1, 1]}, diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index ffa91468..7cf9b8d6 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -18,8 +18,8 @@ # TODO: (1) better device management import math -from typing import Callable, Optional, Tuple, Union, Dict, List from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple, Union import einops import numpy as np @@ -124,17 +124,13 @@ class SACPolicy( self.actor = Policy( encoder=encoder_actor, - network=MLP( - input_dim=encoder_actor.output_dim, **config.actor_network_kwargs - ), + network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs), action_dim=config.output_shapes["action"][0], encoder_is_shared=config.shared_encoder, **config.policy_kwargs, ) if config.target_entropy is None: - config.target_entropy = ( - -np.prod(config.output_shapes["action"][0]) / 2 - ) # (-dim(A)/2) + config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2) # TODO (azouitine): Handle the case where the temparameter is a fixed # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise @@ -146,10 +142,11 @@ class SACPolicy( def _save_pretrained(self, save_directory): """Custom save method to handle TensorDict properly""" - import os import json + import os from dataclasses import asdict - from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME + + from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE from safetensors.torch import save_model save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)) @@ -177,12 +174,14 @@ class SACPolicy( **model_kwargs, ) -> "SACPolicy": """Custom load method to handle loading SAC policy from saved files""" - import os import json + import os from pathlib import Path + from huggingface_hub import hf_hub_download - from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME + from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE from safetensors.torch import load_model + from lerobot.common.policies.sac.configuration_sac import SACConfig # Check if model_id is a local path or a hub model ID @@ -302,14 +301,10 @@ class SACPolicy( ) -> Tensor: self.temperature = self.log_alpha.exp().item() with torch.no_grad(): - next_action_preds, next_log_probs, _ = self.actor( - next_observations, next_observation_features - ) + next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way - next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[ - "action" - ] + next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"] # 2- compute q targets q_targets = self.critic_forward( @@ -353,21 +348,15 @@ class SACPolicy( ).sum() return critics_loss - def compute_loss_temperature( - self, observations, observation_features: Tensor | None = None - ) -> Tensor: + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: """Compute the temperature loss""" # calculate temperature loss with torch.no_grad(): _, log_probs, _ = self.actor(observations, observation_features) - temperature_loss = ( - -self.log_alpha.exp() * (log_probs + self.config.target_entropy) - ).mean() + temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean() return temperature_loss - def compute_loss_actor( - self, observations, observation_features: Tensor | None = None - ) -> Tensor: + def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor: self.temperature = self.log_alpha.exp().item() actions_pi, log_probs, _ = self.actor(observations, observation_features) @@ -408,11 +397,7 @@ class MLP(nn.Module): if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[0])) - layers.append( - activations - if isinstance(activations, nn.Module) - else getattr(nn, activations)() - ) + layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) # Rest of the layers for i in range(1, len(hidden_dims)): @@ -424,11 +409,7 @@ class MLP(nn.Module): layers.append(nn.LayerNorm(hidden_dims[i])) # If we're at the final layer and a final activation is specified, use it - if ( - i + 1 == len(hidden_dims) - and activate_final - and final_activation is not None - ): + if i + 1 == len(hidden_dims) and activate_final and final_activation is not None: layers.append( final_activation if isinstance(final_activation, nn.Module) @@ -436,9 +417,7 @@ class MLP(nn.Module): ) else: layers.append( - activations - if isinstance(activations, nn.Module) - else getattr(nn, activations)() + activations if isinstance(activations, nn.Module) else getattr(nn, activations)() ) self.net = nn.Sequential(*layers) @@ -639,15 +618,11 @@ class Policy(nn.Module): # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - assert not torch.isnan(log_std).any(), ( - "[ERROR] log_std became NaN after std_layer!" - ) + assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!" if self.use_tanh_squash: log_std = torch.tanh(log_std) - log_std = self.log_std_min + 0.5 * ( - self.log_std_max - self.log_std_min - ) * (log_std + 1.0) + log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0) else: log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) else: @@ -660,9 +635,7 @@ class Policy(nn.Module): if self.use_tanh_squash: actions = torch.tanh(x_t) - log_probs -= torch.log( - (1 - actions.pow(2)) + 1e-6 - ) # Adjust log-probs for Tanh + log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh else: actions = x_t # No Tanh; raw Gaussian sample @@ -709,9 +682,7 @@ class SACObservationEncoder(nn.Module): freeze_image_encoder(self.image_enc_layers) else: self.parameters_to_optimize += list(self.image_enc_layers.parameters()) - self.all_image_keys = [ - k for k in config.input_shapes if k.startswith("observation.image") - ] + self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] if "observation.state" in config.input_shapes: self.state_enc_layers = nn.Sequential( @@ -738,9 +709,7 @@ class SACObservationEncoder(nn.Module): self.aggregation_size += config.latent_dim self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) - self.aggregation_layer = nn.Linear( - in_features=self.aggregation_size, out_features=config.latent_dim - ) + self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim) self.parameters_to_optimize += list(self.aggregation_layer.parameters()) def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: @@ -753,19 +722,13 @@ class SACObservationEncoder(nn.Module): obs_dict = self.input_normalization(obs_dict) # Batch all images along the batch dimension, then encode them. if len(self.all_image_keys) > 0: - images_batched = torch.cat( - [obs_dict[key] for key in self.all_image_keys], dim=0 - ) + images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0) images_batched = self.image_enc_layers(images_batched) - embeddings_chunks = torch.chunk( - images_batched, dim=0, chunks=len(self.all_image_keys) - ) + embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys)) feat.extend(embeddings_chunks) if "observation.environment_state" in self.config.input_shapes: - feat.append( - self.env_state_enc_layers(obs_dict["observation.environment_state"]) - ) + feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) @@ -833,9 +796,7 @@ class PretrainedImageEncoder(nn.Module): def __init__(self, config): super().__init__() - self.image_enc_layers, self.image_enc_out_shape = ( - self._load_pretrained_vision_encoder(config) - ) + self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) self.image_enc_proj = nn.Sequential( nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), nn.LayerNorm(config.latent_dim), @@ -846,21 +807,15 @@ class PretrainedImageEncoder(nn.Module): """Set up CNN encoder""" from transformers import AutoModel - self.image_enc_layers = AutoModel.from_pretrained( - config.vision_encoder_name, trust_remote_code=True - ) + self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True) # self.image_enc_layers.pooler = Identity() if hasattr(self.image_enc_layers.config, "hidden_sizes"): - self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[ - -1 - ] # Last channel dimension + self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension elif hasattr(self.image_enc_layers, "fc"): self.image_enc_out_shape = self.image_enc_layers.fc.in_features else: - raise ValueError( - "Unsupported vision encoder architecture, make sure you are using a CNN" - ) + raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN") return self.image_enc_layers, self.image_enc_out_shape def forward(self, x): @@ -896,9 +851,7 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: for key, value in inner_dict.items(): converted_params[outer_key][key] = torch.tensor(value) if "image" in outer_key: - converted_params[outer_key][key] = converted_params[outer_key][ - key - ].view(3, 1, 1) + converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1) return converted_params diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 241e8d80..3fce01df 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -183,13 +183,9 @@ class TDMPCConfig(PreTrainedConfig): "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1." ) if not self.use_mpc: - raise ValueError( - "If `n_action_steps > 1`, `use_mpc` must be set to `True`." - ) + raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.") if self.n_action_steps > self.horizon: - raise ValueError( - "`n_action_steps` must be less than or equal to `horizon`." - ) + raise ValueError("`n_action_steps` must be less than or equal to `horizon`.") def get_optimizer_preset(self) -> AdamConfig: return AdamConfig(lr=self.optimizer_lr) @@ -209,9 +205,7 @@ class TDMPCConfig(PreTrainedConfig): if image_ft.shape[-2] != image_ft.shape[-1]: # TODO(alexander-soare): This limitation is solely because of code in the random shift # augmentation. It should be able to be removed. - raise ValueError( - f"Only square images are handled now. Got image shape {image_ft.shape}." - ) + raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.") @property def observation_delta_indices(self) -> list: diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index a81d23bd..6686e1ef 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -83,9 +83,7 @@ class TDMPCPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize( - config.input_features, config.normalization_mapping, dataset_stats - ) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -110,9 +108,7 @@ class TDMPCPolicy(PreTrainedPolicy): """ self._queues = { "observation.state": deque(maxlen=1), - "action": deque( - maxlen=max(self.config.n_action_steps, self.config.n_action_repeats) - ), + "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), } if self.config.image_features: self._queues["observation.image"] = deque(maxlen=1) @@ -127,9 +123,7 @@ class TDMPCPolicy(PreTrainedPolicy): """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[next(iter(self.config.image_features))] self._queues = populate_queues(self._queues, batch) @@ -232,47 +226,35 @@ class TDMPCPolicy(PreTrainedPolicy): self.config.action_feature.shape[0], device=std.device, ) - gaussian_actions = torch.clamp( - mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1 - ) + gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) # Compute elite actions. actions = torch.cat([gaussian_actions, pi_actions], dim=1) value = self.estimate_value(z, actions).nan_to_num_(0) - elite_idxs = torch.topk( - value, self.config.n_elites, dim=0 - ).indices # (n_elites, batch) + elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch) elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch) # (horizon, n_elites, batch, action_dim) - elite_actions = actions.take_along_dim( - einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1 - ) + elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1) # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites. max_value = elite_value.max(0, keepdim=True)[0] # (1, batch) # The weighting is a softmax over trajectory values. Note that this is not the same as the usage # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). - score = torch.exp( - self.config.elite_weighting_temperature * (elite_value - max_value) - ) + score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) score /= score.sum(axis=0, keepdim=True) # (horizon, batch, action_dim) - _mean = torch.sum( - einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1 - ) + _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) _std = torch.sqrt( torch.sum( einops.rearrange(score, "n b -> n b 1") - * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) - ** 2, + * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2, dim=1, ) ) # Update mean with an exponential moving average, and std with a direct replacement. mean = ( - self.config.gaussian_mean_momentum * mean - + (1 - self.config.gaussian_mean_momentum) * _mean + self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean ) std = _std.clamp_(self.config.min_std, self.config.max_std) @@ -281,9 +263,7 @@ class TDMPCPolicy(PreTrainedPolicy): # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax # scores from the last iteration. - actions = elite_actions[ - :, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size) - ] + actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)] return actions @@ -306,8 +286,7 @@ class TDMPCPolicy(PreTrainedPolicy): # of the FOWM paper. if self.config.uncertainty_regularizer_coeff > 0: regularization = -( - self.config.uncertainty_regularizer_coeff - * self.model.Qs(z, actions[t]).std(0) + self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0) ) else: regularization = 0 @@ -328,9 +307,7 @@ class TDMPCPolicy(PreTrainedPolicy): G += ( running_discount * torch.min( - terminal_values[ - torch.randint(0, self.config.q_ensemble_size, size=(2,)) - ], + terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0, )[0] ) @@ -338,11 +315,7 @@ class TDMPCPolicy(PreTrainedPolicy): G += running_discount * torch.min(terminal_values, dim=0)[0] # Finally, also regularize the terminal value. if self.config.uncertainty_regularizer_coeff > 0: - G -= ( - running_discount - * self.config.uncertainty_regularizer_coeff - * terminal_values.std(0) - ) + G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) return G def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: @@ -354,9 +327,7 @@ class TDMPCPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[next(iter(self.config.image_features))] batch = self.normalize_targets(batch) @@ -388,29 +359,21 @@ class TDMPCPolicy(PreTrainedPolicy): current_observation[k] = observations[k][0] next_observations[k] = observations[k][1:] horizon, batch_size = next_observations[ - "observation.image" - if self.config.image_features - else "observation.environment_state" + "observation.image" if self.config.image_features else "observation.environment_state" ].shape[:2] # Run latent rollout using the latent dynamics model and policy model. # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action # gives us a next `z`. batch_size = batch["index"].shape[0] - z_preds = torch.empty( - horizon + 1, batch_size, self.config.latent_dim, device=device - ) + z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) z_preds[0] = self.model.encode(current_observation) reward_preds = torch.empty_like(reward, device=device) for t in range(horizon): - z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward( - z_preds[t], action[t] - ) + z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t]) # Compute Q and V value predictions based on the latent rollout. - q_preds_ensemble = self.model.Qs( - z_preds[:-1], action - ) # (ensemble, horizon, batch) + q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch) v_preds = self.model.V(z_preds[:-1]) info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()}) @@ -424,14 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy): # actions (not actions estimated by π). # Note: Here we do not use self.model_target, but self.model. This is to follow the original code # and the FOWM paper. - q_targets = reward + self.config.discount * self.model.V( - self.model.encode(next_observations) - ) + q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations)) # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we # are using them to compute loss for V. - v_targets = self.model_target.Qs( - z_preds[:-1].detach(), action, return_min=True - ) + v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True) # Compute losses. # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the @@ -474,9 +433,7 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * F.mse_loss( q_preds_ensemble, - einops.repeat( - q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0] - ), + einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), reduction="none", ).sum(0) # sum over ensemble # `q_preds_ensemble` depends on the first observation and the actions. @@ -514,14 +471,12 @@ class TDMPCPolicy(PreTrainedPolicy): z_preds = z_preds.detach() # Use stopgrad for the advantage calculation. with torch.no_grad(): - advantage = self.model_target.Qs( - z_preds[:-1], action, return_min=True - ) - self.model.V(z_preds[:-1]) + advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V( + z_preds[:-1] + ) info["advantage"] = advantage[0] # (t, b) - exp_advantage = torch.clamp( - torch.exp(advantage * self.config.advantage_scaling), max=100.0 - ) + exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0) action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) # Calculate the MSE between the actions and the action predictions. # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation @@ -575,9 +530,7 @@ class TDMPCPolicy(PreTrainedPolicy): # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995) - update_ema_parameters( - self.model_target, self.model, self.config.target_model_momentum - ) + update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) class TDMPCTOLD(nn.Module): @@ -588,9 +541,7 @@ class TDMPCTOLD(nn.Module): self.config = config self._encoder = TDMPCObservationEncoder(config) self._dynamics = nn.Sequential( - nn.Linear( - config.latent_dim + config.action_feature.shape[0], config.mlp_dim - ), + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), nn.LayerNorm(config.mlp_dim), nn.Mish(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -601,9 +552,7 @@ class TDMPCTOLD(nn.Module): nn.Sigmoid(), ) self._reward = nn.Sequential( - nn.Linear( - config.latent_dim + config.action_feature.shape[0], config.mlp_dim - ), + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), nn.LayerNorm(config.mlp_dim), nn.Mish(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -671,9 +620,7 @@ class TDMPCTOLD(nn.Module): "Sanity check. The last linear layer needs 0 initialization on weights." ) nn.init.zeros_(m[-1].weight) - nn.init.zeros_( - m[-1].bias - ) # this has already been done, but keep this line here for good measure + nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure def encode(self, obs: dict[str, Tensor]) -> Tensor: """Encodes an observation into its latent representation.""" @@ -812,9 +759,7 @@ class TDMPCObservationEncoder(nn.Module): if config.robot_state_feature: self.state_enc_layers = nn.Sequential( - nn.Linear( - config.robot_state_feature.shape[0], config.state_encoder_hidden_dim - ), + nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim), nn.ELU(), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), nn.LayerNorm(config.latent_dim), @@ -823,9 +768,7 @@ class TDMPCObservationEncoder(nn.Module): if config.env_state_feature: self.env_state_enc_layers = nn.Sequential( - nn.Linear( - config.env_state_feature.shape[0], config.state_encoder_hidden_dim - ), + nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim), nn.ELU(), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), nn.LayerNorm(config.latent_dim), @@ -898,10 +841,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): assert n_p_ema == n_p, "Parameter names don't match for EMA model update" if isinstance(p, dict): raise RuntimeError("Dict parameter not supported") - if ( - isinstance(module, nn.modules.batchnorm._BatchNorm) - or not p.requires_grad - ): + if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad: # Copy BatchNorm parameters, and non-trainable parameters directly. p_ema.copy_(p.to(dtype=p_ema.dtype).data) with torch.no_grad(): @@ -909,9 +849,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha) -def flatten_forward_unflatten( - fn: Callable[[Tensor], Tensor], image_tensor: Tensor -) -> Tensor: +def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. Args: diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index bf0cb2d0..9b74c977 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -172,10 +172,7 @@ class VQBeTConfig(PreTrainedConfig): if self.crop_shape is not None: for key, image_ft in self.image_features.items(): - if ( - self.crop_shape[0] > image_ft.shape[1] - or self.crop_shape[1] > image_ft.shape[2] - ): + if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: raise ValueError( f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " f"for `crop_shape` and {image_ft.shape} for " diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index a26b20d5..a84c8bc8 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -64,9 +64,7 @@ class VQBeTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize( - config.input_features, config.normalization_mapping, dataset_stats - ) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -97,17 +95,11 @@ class VQBeTPolicy(PreTrainedPolicy): if self.config.sequentially_select: decay_params = ( decay_params - + list( - self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters() - ) - + list( - self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters() - ) + + list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()) + + list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()) ) else: - decay_params = decay_params + list( - self.vqbet.action_head.map_to_cbet_preds_bin.parameters() - ) + decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters()) return [ { @@ -145,12 +137,8 @@ class VQBeTPolicy(PreTrainedPolicy): """ batch = self.normalize_inputs(batch) - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[key] for key in self.config.image_features], dim=-4 - ) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -161,14 +149,8 @@ class VQBeTPolicy(PreTrainedPolicy): ) if len(self._queues["action"]) == 0: - batch = { - k: torch.stack(list(self._queues[k]), dim=1) - for k in batch - if k in self._queues - } - actions = self.vqbet(batch, rollout=True)[ - :, : self.config.action_chunk_size - ] + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] # the dimension of returned action is (batch_size, action_chunk_size, action_dim) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -181,12 +163,8 @@ class VQBeTPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch = dict( - batch - ) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[key] for key in self.config.image_features], dim=-4 - ) + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): @@ -194,9 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy): # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree). loss, n_different_codes, n_different_combinations, recon_l1_error = ( - self.vqbet.action_head.discretize( - self.config.n_vqvae_training_steps, batch["action"] - ) + self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) ) return loss, { "n_different_codes": n_different_codes, @@ -253,9 +229,7 @@ class SpatialSoftmax(nn.Module): # we could use torch.linspace directly but that seems to behave slightly differently than numpy # and causes a small degradation in pc_success of pre-trained models. - pos_x, pos_y = np.meshgrid( - np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) - ) + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() # register as buffer so it's moved to the correct device. @@ -370,12 +344,7 @@ class VQBeTModel(nn.Module): num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1 self.register_buffer( "select_target_actions_indices", - torch.row_stack( - [ - torch.arange(i, i + self.config.action_chunk_size) - for i in range(num_tokens) - ] - ), + torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]), ) def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: @@ -406,19 +375,13 @@ class VQBeTModel(nn.Module): input_tokens.append( self.state_projector(batch["observation.state"]) ) # (batch, obs_step, projection dims) - input_tokens.append( - einops.repeat( - self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps - ) - ) + input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) # Interleave tokens by stacking and rearranging. input_tokens = torch.stack(input_tokens, dim=2) input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d") len_additional_action_token = self.config.n_action_pred_token - 1 - future_action_tokens = self.action_token.repeat( - batch_size, len_additional_action_token, 1 - ) + future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1) # add additional action query tokens for predicting future action chunks input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1) @@ -427,9 +390,9 @@ class VQBeTModel(nn.Module): features = self.policy(input_tokens) # len(self.config.input_features) is the number of different observation modes. # this line gets the index of action prompt tokens. - historical_act_pred_index = np.arange(0, n_obs_steps) * ( - len(self.config.input_features) + 1 - ) + len(self.config.input_features) + historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len( + self.config.input_features + ) # only extract the output tokens at the position of action query: # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, @@ -449,15 +412,13 @@ class VQBeTModel(nn.Module): action_head_output = self.action_head(features) # if rollout, VQ-BeT don't calculate loss if rollout: - return action_head_output["predicted_action"][ - :, n_obs_steps - 1, : - ].reshape(batch_size, self.config.action_chunk_size, -1) + return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape( + batch_size, self.config.action_chunk_size, -1 + ) # else, it calculate overall loss (bin prediction loss, and offset loss) else: output = batch["action"][:, self.select_target_actions_indices] - loss = self.action_head.loss_fn( - action_head_output, output, reduction="mean" - ) + loss = self.action_head.loss_fn(action_head_output, output, reduction="mean") return action_head_output, loss @@ -492,9 +453,7 @@ class VQBeTHead(nn.Module): else: self.map_to_cbet_preds_bin = MLP( in_channels=config.gpt_output_dim, - hidden_channels=[ - self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed - ], + hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed], ) self.map_to_cbet_preds_offset = MLP( in_channels=config.gpt_output_dim, @@ -521,10 +480,7 @@ class VQBeTHead(nn.Module): loss, metric = self.vqvae_model.vqvae_forward(actions) n_different_codes = sum( - [ - len(torch.unique(metric[2][:, i])) - for i in range(self.vqvae_model.vqvae_num_layers) - ] + [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)] ) n_different_combinations = len(torch.unique(metric[2], dim=0)) recon_l1_error = metric[0].detach().cpu().item() @@ -585,18 +541,12 @@ class VQBeTHead(nn.Module): cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1 ) sampled_secondary_centers = einops.rearrange( - torch.multinomial( - cbet_secondary_probs.view(-1, choices), num_samples=1 - ), + torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1), "(NT) 1 -> NT", NT=NT, ) - sampled_centers = torch.stack( - (sampled_primary_centers, sampled_secondary_centers), axis=1 - ) - cbet_logits = torch.stack( - [cbet_primary_logits, cbet_secondary_logits], dim=1 - ) + sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1) + cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1) # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once. else: cbet_logits = self.map_to_cbet_preds_bin(x) @@ -605,9 +555,7 @@ class VQBeTHead(nn.Module): "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers, ) - cbet_probs = torch.softmax( - cbet_logits / self.config.bet_softmax_temperature, dim=-1 - ) + cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) NT, G, choices = cbet_probs.shape sampled_centers = einops.rearrange( torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), @@ -627,17 +575,9 @@ class VQBeTHead(nn.Module): sampled_offsets = sampled_offsets.sum(dim=1) with torch.no_grad(): # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder - return_decoder_input = ( - self.vqvae_model.get_embeddings_from_code(sampled_centers) - .clone() - .detach() - ) + return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach() # pass the centroids through decoder to get actions. - decoded_action = ( - self.vqvae_model.get_action_from_latent(return_decoder_input) - .clone() - .detach() - ) + decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach() # reshaped extracted offset to match with decoded centroids sampled_offsets = einops.rearrange( sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size @@ -686,9 +626,7 @@ class VQBeTHead(nn.Module): # Figure out the loss for the actions. # First, we need to find the closest cluster center for each ground truth action. with torch.no_grad(): - state_vq, action_bins = self.vqvae_model.get_code( - action_seq - ) # action_bins: NT, G + state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G # Now we can compute the loss. @@ -711,12 +649,8 @@ class VQBeTHead(nn.Module): + cbet_loss2 * self.config.secondary_code_loss_weight ) - equal_primary_code_rate = torch.sum( - (action_bins[:, 0] == sampled_centers[:, 0]).int() - ) / (NT) - equal_secondary_code_rate = torch.sum( - (action_bins[:, 1] == sampled_centers[:, 1]).int() - ) / (NT) + equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT) + equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT) action_mse_error = torch.mean((action_seq - predicted_action) ** 2) vq_action_error = torch.mean(torch.abs(action_seq - decoded_action)) @@ -730,9 +664,7 @@ class VQBeTHead(nn.Module): "classification_loss": cbet_loss.detach().cpu().item(), "offset_loss": offset_loss.detach().cpu().item(), "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(), - "equal_secondary_code_rate": equal_secondary_code_rate.detach() - .cpu() - .item(), + "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(), "vq_action_error": vq_action_error.detach().cpu().item(), "offset_action_error": offset_action_error.detach().cpu().item(), "action_error_max": action_error_max.detach().cpu().item(), @@ -757,9 +689,7 @@ class VQBeTRgbEncoder(nn.Module): # Always use center crop for eval self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) if config.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop( - config.crop_shape - ) + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) else: self.maybe_random_crop = self.center_crop else: @@ -780,9 +710,7 @@ class VQBeTRgbEncoder(nn.Module): self.backbone = _replace_submodules( root_module=self.backbone, predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm( - num_groups=x.num_features // 16, num_channels=x.num_features - ), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), ) # Set up pooling and final layers. @@ -792,15 +720,11 @@ class VQBeTRgbEncoder(nn.Module): # height and width from `config.image_features`. images_shape = next(iter(config.image_features.values())).shape - dummy_shape_h_w = ( - config.crop_shape if config.crop_shape is not None else images_shape[1:] - ) + dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] dummy_shape = (1, images_shape[0], *dummy_shape_h_w) feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] - self.pool = SpatialSoftmax( - feature_map_shape, num_kp=config.spatial_softmax_num_keypoints - ) + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() @@ -842,11 +766,7 @@ def _replace_submodules( if predicate(root_module): return func(root_module) - replace_list = [ - k.split(".") - for k, m in root_module.named_modules(remove_duplicate=True) - if predicate(m) - ] + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] for *parents, k in replace_list: parent_module = root_module if len(parents) > 0: @@ -861,9 +781,7 @@ def _replace_submodules( else: setattr(parent_module, k, tgt_module) # verify that all BN are replaced - assert not any( - predicate(m) for _, m in root_module.named_modules(remove_duplicate=True) - ) + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) return root_module @@ -896,8 +814,7 @@ class VqVae(nn.Module): ) self.encoder = MLP( - in_channels=self.config.action_feature.shape[0] - * self.config.action_chunk_size, + in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size, hidden_channels=[ config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, @@ -925,13 +842,9 @@ class VqVae(nn.Module): # given latent vector, this function outputs the decoded action. output = self.decoder(latent) if self.config.action_chunk_size == 1: - return einops.rearrange( - output, "N (T A) -> N T A", A=self.config.action_feature.shape[0] - ) + return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) else: - return einops.rearrange( - output, "N (T A) -> N T A", A=self.config.action_feature.shape[0] - ) + return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) def get_code(self, state): # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181) diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py index 3fed3a15..248a6ef8 100644 --- a/lerobot/common/policies/vqbet/vqbet_utils.py +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -123,15 +123,9 @@ class CausalSelfAttention(nn.Module): # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2) - k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) - q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) - v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) + k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) @@ -139,9 +133,7 @@ class CausalSelfAttention(nn.Module): att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = ( - y.transpose(1, 2).contiguous().view(B, T, C) - ) # re-assemble all head outputs side by side + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -197,16 +189,12 @@ class GPT(nn.Module): "ln_f": nn.LayerNorm(config.gpt_hidden_dim), } ) - self.lm_head = nn.Linear( - config.gpt_hidden_dim, config.gpt_output_dim, bias=False - ) + self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False) # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith("c_proj.weight"): - torch.nn.init.normal_( - p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer) - ) + torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)) # report number of parameters n_params = sum(p.numel() for p in self.parameters()) @@ -220,17 +208,11 @@ class GPT(nn.Module): ) # positional encodings that are added to the input embeddings - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( - 0 - ) # shape (1, t) + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) # forward the GPT model itself - tok_emb = self.transformer.wte( - input - ) # token embeddings of shape (b, t, gpt_hidden_dim) - pos_emb = self.transformer.wpe( - pos - ) # position embeddings of shape (1, t, gpt_hidden_dim) + tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) @@ -255,9 +237,7 @@ class GPT(nn.Module): # but want to use a smaller block size for some smaller, simpler model assert gpt_block_size <= self.config.gpt_block_size self.config.gpt_block_size = gpt_block_size - self.transformer.wpe.weight = nn.Parameter( - self.transformer.wpe.weight[:gpt_block_size] - ) + self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size]) for block in self.transformer.h: block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size] @@ -290,10 +270,8 @@ class GPT(nn.Module): param_dict = dict(self.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, ( - "parameters {} made it into both decay/no_decay sets!".format( - str(inter_params) - ) + assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) ) assert len(param_dict.keys() - union_params) == 0, ( "parameters {} were not separated into either decay/no_decay set!".format( @@ -390,12 +368,8 @@ class ResidualVQ(nn.Module): codebook_input_dim = codebook_dim * heads requires_projection = codebook_input_dim != dim - self.project_in = ( - nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() - ) - self.project_out = ( - nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() - ) + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() self.num_quantizers = num_quantizers @@ -477,9 +451,7 @@ class ResidualVQ(nn.Module): return all_codes - def forward( - self, x, indices=None, return_all_codes=False, sample_codebook_temp=None - ): + def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None): """ For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss. First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize. @@ -508,17 +480,13 @@ class ResidualVQ(nn.Module): ) ce_losses = [] - should_quantize_dropout = ( - self.training and self.quantize_dropout and not return_loss - ) + should_quantize_dropout = self.training and self.quantize_dropout and not return_loss # sample a layer index at which to dropout further residual quantization # also prepare null indices and loss if should_quantize_dropout: - rand_quantize_dropout_index = randrange( - self.quantize_dropout_cutoff_index, num_quant - ) + rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant) if quant_dropout_multiple_of != 1: rand_quantize_dropout_index = ( @@ -527,23 +495,14 @@ class ResidualVQ(nn.Module): - 1 ) - null_indices_shape = ( - (x.shape[0], *x.shape[-2:]) - if self.accept_image_fmap - else tuple(x.shape[:2]) - ) - null_indices = torch.full( - null_indices_shape, -1.0, device=device, dtype=torch.long - ) + null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2]) + null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long) null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype) # go through the layers for quantizer_index, layer in enumerate(self.layers): - if ( - should_quantize_dropout - and quantizer_index > rand_quantize_dropout_index - ): + if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index: all_indices.append(null_indices) all_losses.append(null_loss) continue @@ -583,9 +542,7 @@ class ResidualVQ(nn.Module): # stack all losses and indices - all_losses, all_indices = map( - partial(torch.stack, dim=-1), (all_losses, all_indices) - ) + all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices)) ret = (quantized_out, all_indices, all_losses) @@ -645,12 +602,8 @@ class VectorQuantize(nn.Module): codebook_input_dim = codebook_dim * heads requires_projection = codebook_input_dim != dim - self.project_in = ( - nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() - ) - self.project_out = ( - nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() - ) + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() self.eps = eps self.commitment_weight = commitment_weight @@ -664,14 +617,10 @@ class VectorQuantize(nn.Module): self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - assert not (ema_update and learnable_codebook), ( - "learnable codebook not compatible with EMA update" - ) + assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update" assert 0 <= sync_update_v <= 1.0 - assert not (sync_update_v > 0.0 and not learnable_codebook), ( - "learnable codebook must be turned on" - ) + assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on" self.sync_update_v = sync_update_v @@ -683,9 +632,7 @@ class VectorQuantize(nn.Module): ) if sync_codebook is None: - sync_codebook = ( - distributed.is_initialized() and distributed.get_world_size() > 1 - ) + sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1 codebook_kwargs = { "dim": codebook_dim, @@ -850,17 +797,11 @@ class VectorQuantize(nn.Module): # quantize again - quantize, embed_ind, distances = self._codebook( - x, **codebook_forward_kwargs - ) + quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) if self.training: # determine code to use for commitment loss - maybe_detach = ( - torch.detach - if not self.learnable_codebook or freeze_codebook - else identity - ) + maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity commit_quantize = maybe_detach(quantize) @@ -870,9 +811,7 @@ class VectorQuantize(nn.Module): if self.sync_update_v > 0.0: # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf - quantize = quantize + self.sync_update_v * ( - quantize - quantize.detach() - ) + quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) # function for calculating cross entropy loss to distance matrix # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss @@ -905,9 +844,7 @@ class VectorQuantize(nn.Module): embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads) if self.accept_image_fmap: - embed_ind = rearrange( - embed_ind, "b (h w) ... -> b h w ...", h=height, w=width - ) + embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width) if only_one: embed_ind = rearrange(embed_ind, "b 1 -> b") @@ -961,12 +898,8 @@ class VectorQuantize(nn.Module): num_codes = codebook.shape[-2] - if ( - self.orthogonal_reg_max_codes is not None - ) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[ - : self.orthogonal_reg_max_codes - ] + if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes] codebook = codebook[:, rand_ids] orthogonal_reg_loss = orthogonal_loss_fn(codebook) @@ -998,9 +931,7 @@ class VectorQuantize(nn.Module): # if masking, only return quantized for where mask has True if mask is not None: - quantize = torch.where( - rearrange(mask, "... -> ... 1"), quantize, orig_input - ) + quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input) return quantize, embed_ind, loss @@ -1110,9 +1041,7 @@ def sample_vectors(samples, num): def batched_sample_vectors(samples, num): - return torch.stack( - [sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0 - ) + return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0) def pad_shape(shape, size, dim=0): @@ -1163,9 +1092,7 @@ def sample_vectors_distributed(local_samples, num): all_num_samples = all_gather_sizes(local_samples, dim=0) if rank == 0: - samples_per_rank = sample_multinomial( - num, all_num_samples / all_num_samples.sum() - ) + samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) else: samples_per_rank = torch.empty_like(all_num_samples) @@ -1278,9 +1205,7 @@ class EuclideanCodebook(nn.Module): self.eps = eps self.threshold_ema_dead_code = threshold_ema_dead_code self.reset_cluster_size = ( - reset_cluster_size - if (reset_cluster_size is not None) - else threshold_ema_dead_code + reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code ) assert callable(gumbel_sample) @@ -1291,14 +1216,8 @@ class EuclideanCodebook(nn.Module): "kmeans init is not compatible with multiple codebooks in distributed environment for now" ) - self.sample_fn = ( - sample_vectors_distributed - if use_ddp and sync_kmeans - else batched_sample_vectors - ) - self.kmeans_all_reduce_fn = ( - distributed.all_reduce if use_ddp and sync_kmeans else noop - ) + self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors + self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop self.all_reduce_fn = distributed.all_reduce if use_ddp else noop self.register_buffer("initted", torch.Tensor([not kmeans_init])) @@ -1437,9 +1356,7 @@ class EuclideanCodebook(nn.Module): distributed.all_reduce(variance_number) batch_variance = variance_number / num_vectors - self.update_with_decay( - "batch_variance", batch_variance, self.affine_param_batch_decay - ) + self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay) def replace(self, batch_samples, batch_mask): for ind, (samples, mask) in enumerate( @@ -1448,9 +1365,7 @@ class EuclideanCodebook(nn.Module): if not torch.any(mask): continue - sampled = self.sample_fn( - rearrange(samples, "... -> 1 ..."), mask.sum().item() - ) + sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item()) sampled = rearrange(sampled, "1 ... -> ...") self.embed.data[ind][mask] = sampled @@ -1474,9 +1389,7 @@ class EuclideanCodebook(nn.Module): def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): needs_codebook_dim = x.ndim < 4 sample_codebook_temp = ( - sample_codebook_temp - if (sample_codebook_temp is not None) - else self.sample_codebook_temp + sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp ) x = x.float() @@ -1504,9 +1417,7 @@ class EuclideanCodebook(nn.Module): if self.affine_param: codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() batch_std = self.batch_variance.clamp(min=1e-5).sqrt() - embed = (embed - self.codebook_mean) * ( - batch_std / codebook_std - ) + self.batch_mean + embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean dist = -cdist(flatten, embed) @@ -1524,9 +1435,7 @@ class EuclideanCodebook(nn.Module): if self.training and self.ema_update and not freeze_codebook: if self.affine_param: - flatten = (flatten - self.batch_mean) * ( - codebook_std / batch_std - ) + self.codebook_mean + flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean if mask is not None: embed_onehot[~mask] = 0.0 @@ -1549,9 +1458,7 @@ class EuclideanCodebook(nn.Module): self.expire_codes_(x) if needs_codebook_dim: - quantize, embed_ind = tuple( - rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind) - ) + quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)) dist = unpack_one(dist, ps, "h * d") diff --git a/lerobot/common/robot_devices/cameras/configs.py b/lerobot/common/robot_devices/cameras/configs.py index ccbc0268..013419a9 100644 --- a/lerobot/common/robot_devices/cameras/configs.py +++ b/lerobot/common/robot_devices/cameras/configs.py @@ -57,9 +57,7 @@ class OpenCVCameraConfig(CameraConfig): self.channels = 3 if self.rotation not in [-90, None, 90, 180]: - raise ValueError( - f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})" - ) + raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") @CameraConfig.register_subclass("intelrealsense") @@ -104,12 +102,8 @@ class IntelRealSenseCameraConfig(CameraConfig): self.channels = 3 - at_least_one_is_not_none = ( - self.fps is not None or self.width is not None or self.height is not None - ) - at_least_one_is_none = ( - self.fps is None or self.width is None or self.height is None - ) + at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None + at_least_one_is_none = self.fps is None or self.width is None or self.height is None if at_least_one_is_not_none and at_least_one_is_none: raise ValueError( "For `fps`, `width` and `height`, either all of them need to be set, or none of them, " @@ -117,6 +111,4 @@ class IntelRealSenseCameraConfig(CameraConfig): ) if self.rotation not in [-90, None, 90, 180]: - raise ValueError( - f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})" - ) + raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py index 2ae220a0..e93044e1 100644 --- a/lerobot/common/robot_devices/cameras/intelrealsense.py +++ b/lerobot/common/robot_devices/cameras/intelrealsense.py @@ -79,9 +79,7 @@ def save_image(img_array, serial_number, frame_index, images_dir): img.save(str(path), quality=100) logging.info(f"Saved image: {path}") except Exception as e: - logging.error( - f"Failed to save image for camera {serial_number} frame {frame_index}: {e}" - ) + logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") def save_images_from_cameras( @@ -159,9 +157,7 @@ def save_images_from_cameras( if time.perf_counter() - start_time > record_time_s: break - print( - f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}" - ) + print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") frame_index += 1 finally: @@ -279,9 +275,7 @@ class IntelRealSenseCamera: f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." ) - name_to_serial_dict = { - cam["name"]: cam["serial_number"] for cam in camera_infos - } + name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} cam_sn = name_to_serial_dict[name] return cam_sn @@ -353,9 +347,7 @@ class IntelRealSenseCamera: actual_height = color_profile.height() # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose( - self.fps, actual_fps, rel_tol=1e-3 - ): + if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): # Using `OSError` since it's a broad that encompasses issues related to device communication raise OSError( f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}." @@ -375,9 +367,7 @@ class IntelRealSenseCamera: self.is_connected = True - def read( - self, temporary_color: str | None = None - ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3) of type `np.uint8`, contrarily to the pytorch format which is float channel first. @@ -404,15 +394,11 @@ class IntelRealSenseCamera: color_frame = frame.get_color_frame() if not color_frame: - raise OSError( - f"Can't capture color image from IntelRealSenseCamera({self.serial_number})." - ) + raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).") color_image = np.asanyarray(color_frame.get_data()) - requested_color_mode = ( - self.color_mode if temporary_color is None else temporary_color - ) + requested_color_mode = self.color_mode if temporary_color is None else temporary_color if requested_color_mode not in ["rgb", "bgr"]: raise ValueError( f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." @@ -440,9 +426,7 @@ class IntelRealSenseCamera: if self.use_depth: depth_frame = frame.get_depth_frame() if not depth_frame: - raise OSError( - f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})." - ) + raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).") depth_map = np.asanyarray(depth_frame.get_data()) @@ -484,9 +468,7 @@ class IntelRealSenseCamera: # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here num_tries += 1 time.sleep(1 / self.fps) - if num_tries > self.fps and ( - self.thread.ident is None or not self.thread.is_alive() - ): + if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): raise Exception( "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." ) diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index 75fb4829..c9226805 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -45,14 +45,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc MAX_OPENCV_INDEX = 60 -def find_cameras( - raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False -) -> list[dict]: +def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: cameras = [] if platform.system() == "Linux": - print( - "Linux detected. Finding available camera indices through scanning '/dev/video*' ports" - ) + print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") possible_ports = [str(port) for port in Path("/dev").glob("video*")] ports = _find_cameras(possible_ports, mock=mock) for port in ports: @@ -144,9 +140,7 @@ def save_images_from_cameras( print("Connecting cameras") cameras = [] for cam_idx in camera_ids: - config = OpenCVCameraConfig( - camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock - ) + config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock) camera = OpenCVCamera(config) camera.connect() print( @@ -186,9 +180,7 @@ def save_images_from_cameras( dt_s = time.perf_counter() - now busy_wait(1 / fps - dt_s) - print( - f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}" - ) + print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") if time.perf_counter() - start_time > record_time_s: break @@ -245,16 +237,12 @@ class OpenCVCamera: if platform.system() == "Linux": if isinstance(self.camera_index, int): self.port = Path(f"/dev/video{self.camera_index}") - elif isinstance(self.camera_index, str) and is_valid_unix_path( - self.camera_index - ): + elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): self.port = Path(self.camera_index) # Retrieve the camera index from a potentially symlinked path self.camera_index = get_camera_index_from_unix_port(self.port) else: - raise ValueError( - f"Please check the provided camera_index: {self.camera_index}" - ) + raise ValueError(f"Please check the provided camera_index: {self.camera_index}") # Store the raw (capture) resolution from the config. self.capture_width = config.width @@ -295,9 +283,7 @@ class OpenCVCamera: def connect(self): if self.is_connected: - raise RobotDeviceAlreadyConnectedError( - f"OpenCVCamera({self.camera_index}) is already connected." - ) + raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") if self.mock: import tests.cameras.mock_cv2 as cv2 @@ -318,11 +304,7 @@ class OpenCVCamera: else cv2.CAP_ANY ) - camera_idx = ( - f"/dev/video{self.camera_index}" - if platform.system() == "Linux" - else self.camera_index - ) + camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index # First create a temporary camera trying to access `camera_index`, # and verify it is a valid camera by calling `isOpened`. tmp_camera = cv2.VideoCapture(camera_idx, backend) @@ -362,9 +344,7 @@ class OpenCVCamera: actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT) # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose( - self.fps, actual_fps, rel_tol=1e-3 - ): + if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): # Using `OSError` since it's a broad that encompasses issues related to device communication raise OSError( f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}." @@ -406,9 +386,7 @@ class OpenCVCamera: if not ret: raise OSError(f"Can't capture color image from camera {self.camera_index}.") - requested_color_mode = ( - self.color_mode if temporary_color_mode is None else temporary_color_mode - ) + requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode if requested_color_mode not in ["rgb", "bgr"]: raise ValueError( diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index fbecd52f..cb558c71 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -93,9 +93,7 @@ class RecordControlConfig(ControlConfig): policy_path = parser.get_path_arg("control.policy") if policy_path: cli_overrides = parser.get_cli_overrides("control.policy") - self.policy = PreTrainedConfig.from_pretrained( - policy_path, cli_overrides=cli_overrides - ) + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = policy_path diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index f7951446..53cd508f 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -39,9 +39,7 @@ from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import get_safe_torch_device, has_method -def log_control_info( - robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None -): +def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): log_items = [] if episode_index is not None: log_items.append(f"ep:{episode_index}") @@ -108,9 +106,7 @@ def predict_action(observation, policy, device, use_amp): observation = copy(observation) with ( torch.inference_mode(), - torch.autocast(device_type=device.type) - if device.type == "cuda" and use_amp - else nullcontext(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: @@ -166,9 +162,7 @@ def init_keyboard_listener(assign_rewards=False): print("Right arrow key pressed. Exiting loop...") events["exit_early"] = True elif key == keyboard.Key.left: - print( - "Left arrow key pressed. Exiting loop and rerecord the last episode..." - ) + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") events["rerecord_episode"] = True events["exit_early"] = True elif key == keyboard.Key.esc: @@ -262,9 +256,7 @@ def control_loop( raise ValueError("You need to provide a task as argument in `single_task`.") if dataset is not None and fps is not None and dataset.fps != fps: - raise ValueError( - f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})." - ) + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") timestamp = 0 start_episode_t = time.perf_counter() @@ -297,9 +289,7 @@ def control_loop( dataset.add_frame(frame) # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon) - if (display_data and not is_headless()) or ( - display_data and robot.robot_type.startswith("lekiwi") - ): + if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")): for k, v in action.items(): for i, vv in enumerate(v): rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy())) @@ -389,14 +379,11 @@ def sanity_check_dataset_robot_compatibility( mismatches = [] for field, dataset_value, present_value in fields: - diff = DeepDiff( - dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"] - ) + diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) if diff: mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") if mismatches: raise ValueError( - "Dataset metadata compatibility check failed with mismatches:\n" - + "\n".join(mismatches) + "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) ) diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py index d477b846..ab4dc9ba 100644 --- a/lerobot/common/robot_devices/motors/dynamixel.py +++ b/lerobot/common/robot_devices/motors/dynamixel.py @@ -161,9 +161,7 @@ NUM_READ_RETRY = 10 NUM_WRITE_RETRY = 10 -def convert_degrees_to_steps( - degrees: float | np.ndarray, models: str | list[str] -) -> np.ndarray: +def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: """This function converts the degree range to the step range for indicating motors rotation. It assumes a motor achieves a full rotation by going from -180 degree position to +180. The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. @@ -389,9 +387,7 @@ class DynamixelMotorsBus: indices = [] for idx in tqdm.tqdm(possible_ids): try: - present_idx = self.read_with_motor_ids( - self.motor_models, [idx], "ID", num_retry=num_retry - )[0] + present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] except ConnectionError: continue @@ -407,9 +403,7 @@ class DynamixelMotorsBus: def set_bus_baudrate(self, baudrate): present_bus_baudrate = self.port_handler.getBaudRate() if present_bus_baudrate != baudrate: - print( - f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}." - ) + print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") self.port_handler.setBaudRate(baudrate) if self.port_handler.getBaudRate() != baudrate: @@ -430,9 +424,7 @@ class DynamixelMotorsBus: def set_calibration(self, calibration: dict[str, list]): self.calibration = calibration - def apply_calibration_autocorrect( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct. For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. @@ -445,9 +437,7 @@ class DynamixelMotorsBus: values = self.apply_calibration(values, motor_names) return values - def apply_calibration( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with a "zero position" at 0 degree. @@ -522,9 +512,7 @@ class DynamixelMotorsBus: return values - def autocorrect_calibration( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """This function automatically detects issues with values of motors after calibration, and correct for these issues. Some motors might have values outside of expected maximum bounds after calibration. @@ -566,23 +554,15 @@ class DynamixelMotorsBus: values[i] *= -1 # Convert from initial range to range [-180, 180] degrees - calib_val = ( - (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - ) - in_range = (calib_val > LOWER_BOUND_DEGREE) and ( - calib_val < UPPER_BOUND_DEGREE - ) + calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) # Solve this inequality to find the factor to shift the range into [-180, 180] degrees # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution - low_factor = ( - -(resolution // 2) - values[i] - homing_offset - ) / resolution - upp_factor = ( - (resolution // 2) - values[i] - homing_offset - ) / resolution + low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution + upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: start_pos = self.calibration["start_pos"][calib_idx] @@ -590,9 +570,7 @@ class DynamixelMotorsBus: # Convert from initial range to range [0, 100] in % calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and ( - calib_val < UPPER_BOUND_LINEAR - ) + in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) # Solve this inequality to find the factor to shift the range into [0, 100] % # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 @@ -608,27 +586,19 @@ class DynamixelMotorsBus: factor = math.ceil(low_factor) if factor > upp_factor: - raise ValueError( - f"No integer found between bounds [{low_factor=}, {upp_factor=}]" - ) + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") else: factor = math.ceil(upp_factor) if factor > low_factor: - raise ValueError( - f"No integer found between bounds [{low_factor=}, {upp_factor=}]" - ) + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = ( - f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - ) - in_range_str = ( - f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - ) + out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" logging.warning( f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " @@ -638,9 +608,7 @@ class DynamixelMotorsBus: # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. self.calibration["homing_offset"][calib_idx] += resolution * factor - def revert_calibration( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """Inverse of `apply_calibration`.""" if motor_names is None: motor_names = self.motor_names @@ -679,9 +647,7 @@ class DynamixelMotorsBus: values = np.round(values).astype(np.int32) return values - def read_with_motor_ids( - self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY - ): + def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): if self.mock: import tests.motors.mock_dynamixel_sdk as dxl else: @@ -783,9 +749,7 @@ class DynamixelMotorsBus: values = self.apply_calibration_autocorrect(values, motor_names) # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name( - "delta_timestamp_s", "read", data_name, motor_names - ) + delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) self.logs[delta_ts_name] = time.perf_counter() - start_time # log the utc time at which the data was received @@ -794,9 +758,7 @@ class DynamixelMotorsBus: return values - def write_with_motor_ids( - self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY - ): + def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): if self.mock: import tests.motors.mock_dynamixel_sdk as dxl else: @@ -891,9 +853,7 @@ class DynamixelMotorsBus: ) # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name( - "delta_timestamp_s", "write", data_name, motor_names - ) + delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) self.logs[delta_ts_name] = time.perf_counter() - start_time # TODO(rcadene): should we log the time before sending the write command? diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py index cfa249b8..65bd4147 100644 --- a/lerobot/common/robot_devices/motors/feetech.py +++ b/lerobot/common/robot_devices/motors/feetech.py @@ -140,9 +140,7 @@ NUM_READ_RETRY = 20 NUM_WRITE_RETRY = 20 -def convert_degrees_to_steps( - degrees: float | np.ndarray, models: str | list[str] -) -> np.ndarray: +def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: """This function converts the degree range to the step range for indicating motors rotation. It assumes a motor achieves a full rotation by going from -180 degree position to +180. The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. @@ -370,9 +368,7 @@ class FeetechMotorsBus: indices = [] for idx in tqdm.tqdm(possible_ids): try: - present_idx = self.read_with_motor_ids( - self.motor_models, [idx], "ID", num_retry=num_retry - )[0] + present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] except ConnectionError: continue @@ -388,9 +384,7 @@ class FeetechMotorsBus: def set_bus_baudrate(self, baudrate): present_bus_baudrate = self.port_handler.getBaudRate() if present_bus_baudrate != baudrate: - print( - f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}." - ) + print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") self.port_handler.setBaudRate(baudrate) if self.port_handler.getBaudRate() != baudrate: @@ -411,9 +405,7 @@ class FeetechMotorsBus: def set_calibration(self, calibration: dict[str, list]): self.calibration = calibration - def apply_calibration_autocorrect( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): """This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct. For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. @@ -426,9 +418,7 @@ class FeetechMotorsBus: values = self.apply_calibration(values, motor_names) return values - def apply_calibration( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with a "zero position" at 0 degree. @@ -502,9 +492,7 @@ class FeetechMotorsBus: return values - def autocorrect_calibration( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """This function automatically detects issues with values of motors after calibration, and correct for these issues. Some motors might have values outside of expected maximum bounds after calibration. @@ -543,26 +531,18 @@ class FeetechMotorsBus: values[i] *= -1 # Convert from initial range to range [-180, 180] degrees - calib_val = ( - (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - ) - in_range = (calib_val > LOWER_BOUND_DEGREE) and ( - calib_val < UPPER_BOUND_DEGREE - ) + calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) # Solve this inequality to find the factor to shift the range into [-180, 180] degrees # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution low_factor = ( - -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - - values[i] - - homing_offset + -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset ) / resolution upp_factor = ( - HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - - values[i] - - homing_offset + HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset ) / resolution elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: @@ -571,9 +551,7 @@ class FeetechMotorsBus: # Convert from initial range to range [0, 100] in % calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and ( - calib_val < UPPER_BOUND_LINEAR - ) + in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) # Solve this inequality to find the factor to shift the range into [0, 100] % # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 @@ -589,27 +567,19 @@ class FeetechMotorsBus: factor = math.ceil(low_factor) if factor > upp_factor: - raise ValueError( - f"No integer found between bounds [{low_factor=}, {upp_factor=}]" - ) + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") else: factor = math.ceil(upp_factor) if factor > low_factor: - raise ValueError( - f"No integer found between bounds [{low_factor=}, {upp_factor=}]" - ) + raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = ( - f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - ) - in_range_str = ( - f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - ) + out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" logging.warning( f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " @@ -619,9 +589,7 @@ class FeetechMotorsBus: # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. self.calibration["homing_offset"][calib_idx] += resolution * factor - def revert_calibration( - self, values: np.ndarray | list, motor_names: list[str] | None - ): + def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): """Inverse of `apply_calibration`.""" if motor_names is None: motor_names = self.motor_names @@ -697,9 +665,7 @@ class FeetechMotorsBus: return values - def read_with_motor_ids( - self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY - ): + def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): if self.mock: import tests.motors.mock_scservo_sdk as scs else: @@ -808,9 +774,7 @@ class FeetechMotorsBus: values = self.apply_calibration_autocorrect(values, motor_names) # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name( - "delta_timestamp_s", "read", data_name, motor_names - ) + delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) self.logs[delta_ts_name] = time.perf_counter() - start_time # log the utc time at which the data was received @@ -819,9 +783,7 @@ class FeetechMotorsBus: return values - def write_with_motor_ids( - self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY - ): + def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): if self.mock: import tests.motors.mock_scservo_sdk as scs else: @@ -916,9 +878,7 @@ class FeetechMotorsBus: ) # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name( - "delta_timestamp_s", "write", data_name, motor_names - ) + delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) self.logs[delta_ts_name] = time.perf_counter() - start_time # TODO(rcadene): should we log the time before sending the write command? diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index d0a7f40a..e940b442 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -69,13 +69,9 @@ class ManipulatorRobotConfig(RobotConfig): if not cam.mock: cam.mock = True - if self.max_relative_target is not None and isinstance( - self.max_relative_target, Sequence - ): + if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence): for name in self.follower_arms: - if len(self.follower_arms[name].motors) != len( - self.max_relative_target - ): + if len(self.follower_arms[name].motors) != len(self.max_relative_target): raise ValueError( f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has " f"{len(self.follower_arms[name].motors)} motors. Please make sure that the " diff --git a/lerobot/common/robot_devices/robots/dynamixel_calibration.py b/lerobot/common/robot_devices/robots/dynamixel_calibration.py index 8eb60c9d..98fe8754 100644 --- a/lerobot/common/robot_devices/robots/dynamixel_calibration.py +++ b/lerobot/common/robot_devices/robots/dynamixel_calibration.py @@ -24,7 +24,9 @@ from lerobot.common.robot_devices.motors.dynamixel import ( ) from lerobot.common.robot_devices.motors.utils import MotorsBus -URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" +URL_TEMPLATE = ( + "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" +) # The following positions are provided in nominal degree range ]-180, +180[ # For more info on these constants, see comments in the code where they get used. @@ -35,9 +37,7 @@ ROTATED_POSITION_DEGREE = 90 def assert_drive_mode(drive_mode): # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError( - f"`drive_mode` contains values other than 0 or 1: ({drive_mode})" - ) + raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") def apply_drive_mode(position, drive_mode): @@ -78,16 +78,12 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type ``` """ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError( - "To run calibration, the torque must be disabled on all motors." - ) + raise ValueError("To run calibration, the torque must be disabled on all motors.") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to zero position") - print( - "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) input("Press Enter to continue...") # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. @@ -108,15 +104,10 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view # of the previous motor in the kinetic chain. print("\nMove arm to rotated target position") - print( - "See: " - + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) input("Press Enter to continue...") - rotated_target_pos = convert_degrees_to_steps( - ROTATED_POSITION_DEGREE, arm.motor_models - ) + rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) # Find drive mode by rotating each motor by a quarter of a turn. # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). @@ -125,15 +116,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type # Re-compute homing offset to take into account drive mode rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) - rotated_nearest_pos = compute_nearest_rounded_position( - rotated_drived_pos, arm.motor_models - ) + rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models) homing_offset = rotated_target_pos - rotated_nearest_pos print("\nMove arm to rest position") - print( - "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) input("Press Enter to continue...") print() diff --git a/lerobot/common/robot_devices/robots/feetech_calibration.py b/lerobot/common/robot_devices/robots/feetech_calibration.py index f3a59a0b..1969c640 100644 --- a/lerobot/common/robot_devices/robots/feetech_calibration.py +++ b/lerobot/common/robot_devices/robots/feetech_calibration.py @@ -26,7 +26,9 @@ from lerobot.common.robot_devices.motors.feetech import ( ) from lerobot.common.robot_devices.motors.utils import MotorsBus -URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" +URL_TEMPLATE = ( + "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" +) # The following positions are provided in nominal degree range ]-180, +180[ # For more info on these constants, see comments in the code where they get used. @@ -37,9 +39,7 @@ ROTATED_POSITION_DEGREE = 90 def assert_drive_mode(drive_mode): # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError( - f"`drive_mode` contains values other than 0 or 1: ({drive_mode})" - ) + raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") def apply_drive_mode(position, drive_mode): @@ -140,9 +140,7 @@ def apply_offset(calib, offset): return calib -def run_arm_auto_calibration( - arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str -): +def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): if robot_type == "so100": return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type) elif robot_type == "moss": @@ -151,27 +149,18 @@ def run_arm_auto_calibration( raise ValueError(robot_type) -def run_arm_auto_calibration_so100( - arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str -): +def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError( - "To run calibration, the torque must be disabled on all motors." - ) + raise ValueError("To run calibration, the torque must be disabled on all motors.") if not (robot_type == "so100" and arm_type == "follower"): - raise NotImplementedError( - "Auto calibration only supports the follower of so100 arms for now." - ) + raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to initial position") - print( - "See: " - + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) input("Press Enter to continue...") # Lower the acceleration of the motors (in [0,254]) @@ -225,9 +214,7 @@ def run_arm_auto_calibration_so100( ) calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) - arm.write( - "Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex" - ) + arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex") time.sleep(1) def in_between_move_hook(): @@ -261,13 +248,9 @@ def run_arm_auto_calibration_so100( "shoulder_lift", ) time.sleep(2) - arm.write( - "Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex" - ) + arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex") time.sleep(2) - arm.write( - "Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex" - ) + arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex") time.sleep(2) arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper") time.sleep(2) @@ -288,9 +271,7 @@ def run_arm_auto_calibration_so100( arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") time.sleep(1) arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex") - arm.write( - "Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift" - ) + arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift") time.sleep(1) arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") time.sleep(1) @@ -319,27 +300,18 @@ def run_arm_auto_calibration_so100( return calib_dict -def run_arm_auto_calibration_moss( - arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str -): +def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError( - "To run calibration, the torque must be disabled on all motors." - ) + raise ValueError("To run calibration, the torque must be disabled on all motors.") if not (robot_type == "moss" and arm_type == "follower"): - raise NotImplementedError( - "Auto calibration only supports the follower of moss arms for now." - ) + raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to initial position") - print( - "See: " - + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) input("Press Enter to continue...") # Lower the acceleration of the motors (in [0,254]) @@ -423,12 +395,8 @@ def run_arm_auto_calibration_moss( arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") time.sleep(1) - arm.write( - "Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift" - ) - arm.write( - "Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex" - ) + arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift") + arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex") time.sleep(2) calib_modes = [] @@ -455,9 +423,7 @@ def run_arm_auto_calibration_moss( return calib_dict -def run_arm_manual_calibration( - arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str -): +def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): """This function ensures that a neural network trained on data collected on a given robot can work on another robot. For instance before calibration, setting a same goal position for each motor of two different robots will get two very different positions. But after calibration, @@ -480,16 +446,12 @@ def run_arm_manual_calibration( ``` """ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError( - "To run calibration, the torque must be disabled on all motors." - ) + raise ValueError("To run calibration, the torque must be disabled on all motors.") print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to zero position") - print( - "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) input("Press Enter to continue...") # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. @@ -509,15 +471,10 @@ def run_arm_manual_calibration( # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view # of the previous motor in the kinetic chain. print("\nMove arm to rotated target position") - print( - "See: " - + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) input("Press Enter to continue...") - rotated_target_pos = convert_degrees_to_steps( - ROTATED_POSITION_DEGREE, arm.motor_models - ) + rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) # Find drive mode by rotating each motor by a quarter of a turn. # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). @@ -529,9 +486,7 @@ def run_arm_manual_calibration( homing_offset = rotated_target_pos - rotated_drived_pos print("\nMove arm to rest position") - print( - "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest") - ) + print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) input("Press Enter to continue...") print() diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py index a1ad4c9f..976fc000 100644 --- a/lerobot/common/robot_devices/robots/lekiwi_remote.py +++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py @@ -42,9 +42,7 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event): local_dict = {} for name, cam in cameras.items(): frame = cam.async_read() - ret, buffer = cv2.imencode( - ".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90] - ) + ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) if ret: local_dict[name] = base64.b64encode(buffer).decode("utf-8") else: @@ -76,9 +74,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str): print(f"[INFO] Loaded calibration from {calib_file}") else: print("[INFO] Calibration file not found. Running manual calibration...") - calibration = run_arm_manual_calibration( - motors_bus, "lekiwi", "follower_arm", "follower" - ) + calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower") print(f"[INFO] Calibration complete. Saving to {calib_file}") with open(calib_file, "w") as f: json.dump(calibration, f) @@ -174,9 +170,7 @@ def run_lekiwi(robot_config): f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}" ) else: - for motor, pos in zip( - arm_motor_ids, arm_positions, strict=False - ): + for motor, pos in zip(arm_motor_ids, arm_positions, strict=False): motors_bus.write("Goal_Position", pos, motor) # Process wheel (base) commands. if "raw_velocity" in data: @@ -207,9 +201,7 @@ def run_lekiwi(robot_config): try: pos = motors_bus.read("Present_Position", motor) # Convert the position to a float (or use as is if already numeric). - follower_arm_state.append( - float(pos) if not isinstance(pos, (int, float)) else pos - ) + follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos) except Exception as e: print(f"[ERROR] Reading motor {motor} failed: {e}") diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index fbeaad5e..e7993621 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -285,9 +285,7 @@ class ManipulatorRobot: # to squeeze the gripper and have it spring back to an open position on its own. for name in self.leader_arms: self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write( - "Goal_Position", self.config.gripper_open_degree, "gripper" - ) + self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") # Check both arms can be read for name in self.follower_arms: @@ -323,22 +321,16 @@ class ManipulatorRobot: run_arm_calibration, ) - calibration = run_arm_calibration( - arm, self.robot_type, name, arm_type - ) + calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) elif self.robot_type in ["so100", "moss", "lekiwi"]: from lerobot.common.robot_devices.robots.feetech_calibration import ( run_arm_manual_calibration, ) - calibration = run_arm_manual_calibration( - arm, self.robot_type, name, arm_type - ) + calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) - print( - f"Calibration is done! Saving calibration file '{arm_calib_path}'" - ) + print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") arm_calib_path.parent.mkdir(parents=True, exist_ok=True) with open(arm_calib_path, "w") as f: json.dump(calibration, f) @@ -357,17 +349,13 @@ class ManipulatorRobot: from lerobot.common.robot_devices.motors.dynamixel import TorqueMode if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError( - "To run set robot preset, the torque must be disabled on all motors." - ) + raise ValueError("To run set robot preset, the torque must be disabled on all motors.") # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, # you could end up with a servo with a position 0 or 4095 at a crucial point See [ # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] - all_motors_except_gripper = [ - name for name in arm.motor_names if name != "gripper" - ] + all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] if len(all_motors_except_gripper) > 0: # 4 corresponds to Extended Position on Koch motors arm.write("Operating_Mode", 4, all_motors_except_gripper) @@ -396,9 +384,7 @@ class ManipulatorRobot: # Enable torque on the gripper of the leader arms, and move it to 45 degrees, # so that we can use it as a trigger to close the gripper of the follower arms. self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write( - "Goal_Position", self.config.gripper_open_degree, "gripper" - ) + self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") def set_aloha_robot_preset(self): def set_shadow_(arm): @@ -428,15 +414,11 @@ class ManipulatorRobot: # you could end up with a servo with a position 0 or 4095 at a crucial point See [ # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] all_motors_except_gripper = [ - name - for name in self.follower_arms[name].motor_names - if name != "gripper" + name for name in self.follower_arms[name].motor_names if name != "gripper" ] if len(all_motors_except_gripper) > 0: # 4 corresponds to Extended Position on Aloha motors - self.follower_arms[name].write( - "Operating_Mode", 4, all_motors_except_gripper - ) + self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper) # Use 'position control current based' for follower gripper to be limited by the limit of the current. # It can grasp an object without forcing too much even tho, @@ -484,9 +466,7 @@ class ManipulatorRobot: before_lread_t = time.perf_counter() leader_pos[name] = self.leader_arms[name].read("Present_Position") leader_pos[name] = torch.from_numpy(leader_pos[name]) - self.logs[f"read_leader_{name}_pos_dt_s"] = ( - time.perf_counter() - before_lread_t - ) + self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t # Send goal position to the follower follower_goal_pos = {} @@ -507,18 +487,14 @@ class ManipulatorRobot: if self.config.max_relative_target is not None: present_pos = self.follower_arms[name].read("Present_Position") present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position( - goal_pos, present_pos, self.config.max_relative_target - ) + goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) # Used when record_data=True follower_goal_pos[name] = goal_pos goal_pos = goal_pos.numpy().astype(np.float32) self.follower_arms[name].write("Goal_Position", goal_pos) - self.logs[f"write_follower_{name}_goal_pos_dt_s"] = ( - time.perf_counter() - before_fwrite_t - ) + self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t # Early exit when recording data is not requested if not record_data: @@ -531,9 +507,7 @@ class ManipulatorRobot: before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = ( - time.perf_counter() - before_fread_t - ) + self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t # Create state by concatenating follower current position state = [] @@ -555,12 +529,8 @@ class ManipulatorRobot: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ - "delta_timestamp_s" - ] - self.logs[f"async_read_camera_{name}_dt_s"] = ( - time.perf_counter() - before_camread_t - ) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t # Populate output dictionaries obs_dict, action_dict = {}, {} @@ -584,9 +554,7 @@ class ManipulatorRobot: before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = ( - time.perf_counter() - before_fread_t - ) + self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t # Create state by concatenating follower current position state = [] @@ -601,12 +569,8 @@ class ManipulatorRobot: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ - "delta_timestamp_s" - ] - self.logs[f"async_read_camera_{name}_dt_s"] = ( - time.perf_counter() - before_camread_t - ) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t # Populate output dictionaries and format to pytorch obs_dict = {} @@ -652,9 +616,7 @@ class ManipulatorRobot: if self.config.max_relative_target is not None: present_pos = self.follower_arms[name].read("Present_Position") present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position( - goal_pos, present_pos, self.config.max_relative_target - ) + goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) # Save tensor to concat and return action_sent.append(goal_pos) diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py index bf12c1d5..1ecbf12d 100644 --- a/lerobot/common/robot_devices/robots/mobile_manipulator.py +++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py @@ -271,9 +271,7 @@ class MobileManipulator: calibration = json.load(f) else: print(f"Missing calibration file '{arm_calib_path}'") - calibration = run_arm_manual_calibration( - arm, self.robot_type, name, arm_type - ) + calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") arm_calib_path.parent.mkdir(parents=True, exist_ok=True) with open(arm_calib_path, "w") as f: @@ -303,9 +301,7 @@ class MobileManipulator: bus.write("Torque_Enable", 0, motor_id) # Then filter out wheels - arm_only_dict = { - k: v for k, v in bus.motors.items() if not k.startswith("wheel_") - } + arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")} if not arm_only_dict: continue @@ -377,9 +373,7 @@ class MobileManipulator: if new_arm_state is not None and frames is not None: self.last_frames = frames - remote_arm_state_tensor = torch.tensor( - new_arm_state, dtype=torch.float32 - ) + remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32) self.last_remote_arm_state = remote_arm_state_tensor present_speed = new_speed @@ -405,10 +399,7 @@ class MobileManipulator: def _process_present_speed(self, present_speed: dict) -> torch.Tensor: state_tensor = torch.zeros(3, dtype=torch.int32) if present_speed: - decoded = { - key: MobileManipulator.raw_to_degps(value) - for key, value in present_speed.items() - } + decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()} if "1" in decoded: state_tensor[0] = decoded["1"] if "2" in decoded: @@ -421,9 +412,7 @@ class MobileManipulator: self, record_data: bool = False ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: if not self.is_connected: - raise RobotDeviceNotConnectedError( - "MobileManipulator is not connected. Run `connect()` first." - ) + raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.") speed_setting = self.speed_levels[self.speed_index] xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4 @@ -495,9 +484,7 @@ class MobileManipulator: body_state[2], ) # Convert x,y to mm/s wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32) - combined_state_tensor = torch.cat( - (remote_arm_state_tensor, wheel_state_tensor), dim=0 - ) + combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0) obs_dict = {"observation.state": combined_state_tensor} diff --git a/lerobot/common/robot_devices/robots/stretch.py b/lerobot/common/robot_devices/robots/stretch.py index 813732f0..9cfe6e49 100644 --- a/lerobot/common/robot_devices/robots/stretch.py +++ b/lerobot/common/robot_devices/robots/stretch.py @@ -52,9 +52,7 @@ class StretchRobot(StretchAPI): def connect(self) -> None: self.is_connected = self.startup() if not self.is_connected: - print( - "Another process is already using Stretch. Try running 'stretch_free_robot_process.py'" - ) + print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") raise ConnectionError() for name in self.cameras: @@ -62,9 +60,7 @@ class StretchRobot(StretchAPI): self.is_connected = self.is_connected and self.cameras[name].is_connected if not self.is_connected: - print( - "Could not connect to the cameras, check that all cameras are plugged-in." - ) + print("Could not connect to the cameras, check that all cameras are plugged-in.") raise ConnectionError() self.run_calibration() @@ -109,12 +105,8 @@ class StretchRobot(StretchAPI): before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ - "delta_timestamp_s" - ] - self.logs[f"async_read_camera_{name}_dt_s"] = ( - time.perf_counter() - before_camread_t - ) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t # Populate output dictionaries obs_dict, action_dict = {}, {} @@ -158,12 +150,8 @@ class StretchRobot(StretchAPI): before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ - "delta_timestamp_s" - ] - self.logs[f"async_read_camera_{name}_dt_s"] = ( - time.perf_counter() - before_camread_t - ) + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t # Populate output dictionaries obs_dict = {} diff --git a/lerobot/common/utils/hub.py b/lerobot/common/utils/hub.py index 972b5b2b..df7435c0 100644 --- a/lerobot/common/utils/hub.py +++ b/lerobot/common/utils/hub.py @@ -69,9 +69,7 @@ class HubMixin: if push_to_hub: if repo_id is None: repo_id = save_directory.name # Defaults to `save_directory` name - return self.push_to_hub( - repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs - ) + return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs) return None def _save_pretrained(self, save_directory: Path) -> None: @@ -177,9 +175,7 @@ class HubMixin: The url of the commit of your object in the given repository. """ api = HfApi(token=token) - repo_id = api.create_repo( - repo_id=repo_id, private=private, exist_ok=True - ).repo_id + repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id if commit_message is None: if "Policy" in self.__class__.__name__: diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index e2ce5a87..cd5f8245 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -17,9 +17,7 @@ import importlib import logging -def is_package_available( - pkg_name: str, return_version: bool = False -) -> tuple[bool, str] | bool: +def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py Check if the package spec exists and grab its version to avoid importing a local directory. **Note:** this doesn't work for all packages. diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index 339d864c..c67d8e1e 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -20,16 +20,7 @@ from typing import TypeVar import imageio -JsonLike = ( - str - | int - | float - | bool - | None - | list["JsonLike"] - | dict[str, "JsonLike"] - | tuple["JsonLike", ...] -) +JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...] T = TypeVar("T", bound=JsonLike) @@ -85,9 +76,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T: # Check length if len(target) != len(source): - raise ValueError( - f"List length mismatch: expected {len(target)}, got {len(source)}" - ) + raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}") # Recursively update each element. for i in range(len(target)): @@ -99,14 +88,10 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T: # which we'll convert back to a tuple. elif isinstance(target, tuple): if not isinstance(source, list): - raise TypeError( - f"Type mismatch: expected list (for tuple), got {type(source)}" - ) + raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}") if len(target) != len(source): - raise ValueError( - f"Tuple length mismatch: expected {len(target)}, got {len(source)}" - ) + raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}") # Convert each element, forming a new tuple. converted_items = [] @@ -120,9 +105,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T: else: # Check the exact type. If these must match 1:1, do: if type(target) is not type(source): - raise TypeError( - f"Type mismatch: expected {type(target)}, got {type(source)}" - ) + raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}") return source # Perform the in-place/recursive deserialization diff --git a/lerobot/common/utils/logging_utils.py b/lerobot/common/utils/logging_utils.py index 58cb2453..56c9abb2 100644 --- a/lerobot/common/utils/logging_utils.py +++ b/lerobot/common/utils/logging_utils.py @@ -107,17 +107,13 @@ class MetricsTracker: self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames - def __getattr__( - self, name: str - ) -> int | dict[str, AverageMeter] | AverageMeter | Any: + def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: if name in self.__dict__: return self.__dict__[name] elif name in self.metrics: return self.metrics[name] else: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") def __setattr__(self, name: str, value: Any) -> None: if name in self.__dict__: @@ -125,9 +121,7 @@ class MetricsTracker: elif name in self.metrics: self.metrics[name].update(value) else: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") def step(self) -> None: """ diff --git a/lerobot/common/utils/random_utils.py b/lerobot/common/utils/random_utils.py index c9327125..69e5e8f9 100644 --- a/lerobot/common/utils/random_utils.py +++ b/lerobot/common/utils/random_utils.py @@ -123,9 +123,7 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: """ py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")} np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")} - torch_rng_state_dict = { - k: v for k, v in rng_state_dict.items() if k.startswith("torch") - } + torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")} deserialize_python_rng_state(py_rng_state_dict) deserialize_numpy_rng_state(np_rng_state_dict) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index e0b6aafc..ebdc29e9 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -48,9 +48,7 @@ def auto_select_torch_device() -> torch.device: logging.info("Metal backend detected, using cuda.") return torch.device("mps") else: - logging.warning( - "No accelerated backend detected. Using default cpu, this will be slow." - ) + logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") return torch.device("cpu") @@ -98,9 +96,7 @@ def is_torch_device_available(try_device: str) -> bool: elif try_device == "cpu": return True else: - raise ValueError( - f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu." - ) + raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.") def is_amp_available(device: str): @@ -158,10 +154,7 @@ def _relative_path_between(path1: Path, path2: Path) -> Path: except ValueError: # most likely because path1 is not a subpath of path2 common_parts = Path(osp.commonpath([path1, path2])).parts return Path( - "/".join( - [".."] * (len(path2.parts) - len(common_parts)) - + list(path1.parts[len(common_parts) :]) - ) + "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) ) @@ -172,26 +165,10 @@ def print_cuda_memory_usage(): gc.collect() # Also clear the cache if you want to fully release the memory torch.cuda.empty_cache() - print( - "Current GPU Memory Allocated: {:.2f} MB".format( - torch.cuda.memory_allocated(0) / 1024**2 - ) - ) - print( - "Maximum GPU Memory Allocated: {:.2f} MB".format( - torch.cuda.max_memory_allocated(0) / 1024**2 - ) - ) - print( - "Current GPU Memory Reserved: {:.2f} MB".format( - torch.cuda.memory_reserved(0) / 1024**2 - ) - ) - print( - "Maximum GPU Memory Reserved: {:.2f} MB".format( - torch.cuda.max_memory_reserved(0) / 1024**2 - ) - ) + print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) + print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) + print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) + print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) def capture_timestamp_utc(): @@ -223,9 +200,7 @@ def say(text, blocking=False): if blocking: subprocess.run(cmd, check=True) else: - subprocess.Popen( - cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0 - ) + subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0) def log_say(text, play_sounds, blocking=False): diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index ca4c702f..3fe241d4 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -26,9 +26,7 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR from lerobot.configs.train import TrainPipelineConfig -def cfg_to_group( - cfg: TrainPipelineConfig, return_list: bool = False -) -> list[str] | str: +def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: """Return a group name for logging. Optionally returns group name as list.""" lst = [ f"policy:{cfg.policy.type}", @@ -95,9 +93,7 @@ class WandBLogger: mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", ) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) - logging.info( - f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}" - ) + logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb def log_policy(self, checkpoint_dir: Path): @@ -109,9 +105,7 @@ class WandBLogger: artifact_name = f"{self._group}-{step_id}" artifact_name = get_safe_wandb_artifact_name(artifact_name) artifact = self._wandb.Artifact(artifact_name, type="model") - artifact.add_file( - checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE - ) + artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) self._wandb.log_artifact(artifact) def log_dict(self, d: dict, step: int, mode: str = "train"): diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py index 4bcde16f..ce72466a 100644 --- a/lerobot/configs/default.py +++ b/lerobot/configs/default.py @@ -33,9 +33,7 @@ class DatasetConfig: # Root directory where the dataset will be stored (e.g. 'dataset/path'). root: str | None = None episodes: list[int] | None = None - image_transforms: ImageTransformsConfig = field( - default_factory=ImageTransformsConfig - ) + image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) revision: str | None = None use_imagenet_stats: bool = True video_backend: str = field(default_factory=get_safe_default_codec) diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py index f1a26d47..16b35291 100644 --- a/lerobot/configs/eval.py +++ b/lerobot/configs/eval.py @@ -40,9 +40,7 @@ class EvalPipelineConfig: policy_path = parser.get_path_arg("policy") if policy_path: cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained( - policy_path, cli_overrides=cli_overrides - ) + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = policy_path else: diff --git a/lerobot/configs/parser.py b/lerobot/configs/parser.py index 3cdf0578..f4d52f5c 100644 --- a/lerobot/configs/parser.py +++ b/lerobot/configs/parser.py @@ -29,9 +29,7 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" draccus.set_config_type("json") -def get_cli_overrides( - field_name: str, args: Sequence[str] | None = None -) -> list[str] | None: +def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None: """Parses arguments from cli at a given nested attribute level. For example, supposing the main script was called with: @@ -158,9 +156,7 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[ return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")] -def filter_path_args( - fields_to_filter: str | list[str], args: Sequence[str] | None = None -) -> list[str]: +def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]: """ Filters command-line arguments related to fields with specific path arguments. @@ -188,9 +184,7 @@ def filter_path_args( argument=None, message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}", ) - filtered_args = [ - arg for arg in filtered_args if not arg.startswith(f"--{field}.") - ] + filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")] return filtered_args @@ -222,9 +216,7 @@ def wrap(config_path: Path | None = None): load_plugin(plugin_path) except PluginLoadError as e: # add the relevant CLI arg to the error message - raise PluginLoadError( - f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}" - ) from e + raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e cli_args = filter_arg(plugin_cli_arg, cli_args) config_path_cli = parse_arg("config_path", cli_args) if has_method(argtype, "__get_path_fields__"): @@ -234,9 +226,7 @@ def wrap(config_path: Path | None = None): cli_args = filter_arg("config_path", cli_args) cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) else: - cfg = draccus.parse( - config_class=argtype, config_path=config_path, args=cli_args - ) + cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args) response = fn(cfg, *args, **kwargs) return response diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py index 32c3125f..6218d39a 100644 --- a/lerobot/configs/policies.py +++ b/lerobot/configs/policies.py @@ -68,9 +68,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): self.pretrained_path = None if not self.device or not is_torch_device_available(self.device): auto_device = auto_select_torch_device() - logging.warning( - f"Device '{self.device}' is not available. Switching to '{auto_device}'." - ) + logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.") self.device = auto_device.type # Automatically deactivate AMP if necessary @@ -124,11 +122,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def image_features(self) -> dict[str, PolicyFeature]: - return { - key: ft - for key, ft in self.input_features.items() - if ft.type is FeatureType.VISUAL - } + return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL} @property def action_feature(self) -> PolicyFeature | None: diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index b4adeb89..1e2f6544 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -73,9 +73,7 @@ class TrainPipelineConfig(HubMixin): if policy_path: # Only load the policy config cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained( - policy_path, cli_overrides=cli_overrides - ) + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = policy_path elif self.resume: # The entire train config is already loaded, we just need to get the checkpoint dir @@ -99,11 +97,7 @@ class TrainPipelineConfig(HubMixin): else: self.job_name = f"{self.env.type}_{self.policy.type}" - if ( - not self.resume - and isinstance(self.output_dir, Path) - and self.output_dir.is_dir() - ): + if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir(): raise FileExistsError( f"Output directory {self.output_dir} already exists and resume is {self.resume}. " f"Please change your output directory so that {self.output_dir} is not overwritten." @@ -114,16 +108,10 @@ class TrainPipelineConfig(HubMixin): self.output_dir = Path("outputs/train") / train_dir if isinstance(self.dataset.repo_id, list): - raise NotImplementedError( - "LeRobotMultiDataset is not currently implemented." - ) + raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") - if not self.use_policy_training_preset and ( - self.optimizer is None or self.scheduler is None - ): - raise ValueError( - "Optimizer and Scheduler must be set when the policy presets are not used." - ) + if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): + raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") elif self.use_policy_training_preset and not self.resume: self.optimizer = self.policy.get_optimizer_preset() self.scheduler = self.policy.get_scheduler_preset() diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py index 985181a8..60aee664 100644 --- a/lerobot/scripts/configure_motor.py +++ b/lerobot/scripts/configure_motor.py @@ -67,8 +67,8 @@ def get_motor_bus_cls(brand: str) -> tuple: def configure_motor(port, brand, model, motor_idx_des, baudrate_des): - motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = ( - get_motor_bus_cls(brand) + motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls( + brand ) # Check if the provided model exists in the model_baud_rate_table @@ -82,9 +82,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument motor_model = model # Use the motor model passed via argument - config = motor_bus_config_cls( - port=port, motors={motor_name: (motor_index_arbitrary, motor_model)} - ) + config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}) # Initialize the MotorBus with the correct port and motor configurations motor_bus = motor_bus_cls(config=config) @@ -120,26 +118,20 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): break if motor_index == -1: - raise ValueError( - "No motors detected. Please ensure you have one motor connected." - ) + raise ValueError("No motors detected. Please ensure you have one motor connected.") print(f"Motor index found at: {motor_index}") if brand == "feetech": # Allows ID and BAUDRATE to be written in memory - motor_bus.write_with_motor_ids( - motor_bus.motor_models, motor_index, "Lock", 0 - ) + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) if baudrate != baudrate_des: print(f"Setting its baudrate to {baudrate_des}") baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des) # The write can fail, so we allow retries - motor_bus.write_with_motor_ids( - motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx - ) + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx) time.sleep(0.5) motor_bus.set_bus_baudrate(baudrate_des) present_baudrate_idx = motor_bus.read_with_motor_ids( @@ -151,16 +143,10 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): print(f"Setting its index to desired index {motor_idx_des}") if brand == "feetech": - motor_bus.write_with_motor_ids( - motor_bus.motor_models, motor_index, "Lock", 0 - ) - motor_bus.write_with_motor_ids( - motor_bus.motor_models, motor_index, "ID", motor_idx_des - ) + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) + motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des) - present_idx = motor_bus.read_with_motor_ids( - motor_bus.motor_models, motor_idx_des, "ID", num_retry=2 - ) + present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2) if present_idx != motor_idx_des: raise OSError("Failed to write index.") @@ -194,12 +180,8 @@ if __name__ == "__main__": required=True, help="Motors bus port (e.g. dynamixel,feetech)", ) - parser.add_argument( - "--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)" - ) - parser.add_argument( - "--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)" - ) + parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)") + parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)") parser.add_argument( "--ID", type=int, diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index f44f2258..0399c0e1 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -255,8 +255,7 @@ def record( if len(robot.cameras) > 0: dataset.start_image_writer( num_processes=cfg.num_image_writer_processes, - num_threads=cfg.num_image_writer_threads_per_camera - * len(robot.cameras), + num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), ) sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) else: @@ -269,19 +268,14 @@ def record( robot=robot, use_videos=cfg.video, image_writer_processes=cfg.num_image_writer_processes, - image_writer_threads=cfg.num_image_writer_threads_per_camera - * len(robot.cameras), + image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), ) # Load pretrained policy - policy = ( - None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) - ) + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) # Load pretrained policy - policy = ( - None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) - ) + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) if not robot.is_connected: robot.connect() diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 2bf72640..5e62b88d 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -174,10 +174,7 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None): leader_pos = robot.leader_arms.main.read("Present_Position") action = process_action_fn(leader_pos) env.step(np.expand_dims(action, 0)) - if ( - teleop_time_s is not None - and time.perf_counter() - start_teleop_t > teleop_time_s - ): + if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: print("Teleoperation processes finished.") break @@ -209,27 +206,19 @@ def record( # Load pretrained policy extra_features = ( - {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} - if assign_rewards - else None + {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None ) policy = None if pretrained_policy_name_or_path is not None: - policy, policy_fps, device, use_amp = init_policy( - pretrained_policy_name_or_path, policy_overrides - ) + policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) if fps is None: fps = policy_fps - logging.warning( - f"No fps provided, so using the fps from policy config ({policy_fps})." - ) + logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") if policy is None and process_action_from_leader is None: - raise ValueError( - "Either policy or process_action_fn has to be set to enable control in sim." - ) + raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") # initialize listener before sim env listener, events = init_keyboard_listener(assign_rewards=assign_rewards) @@ -380,9 +369,7 @@ def record( if events["stop_recording"] or recorded_episodes >= num_episodes: break else: - logging.info( - "Waiting for a few seconds before starting next episode recording..." - ) + logging.info("Waiting for a few seconds before starting next episode recording...") busy_wait(3) log_say("Stop recording", play_sounds, blocking=True) @@ -481,9 +468,7 @@ if __name__ == "__main__": required=True, help="A description of the task preformed during recording that can be used as a language instruction.", ) - parser_record.add_argument( - "--num-episodes", type=int, default=50, help="Number of episodes to record." - ) + parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") parser_record.add_argument( "--run-compute-stats", type=int, @@ -561,9 +546,7 @@ if __name__ == "__main__": default="lerobot/test", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", ) - parser_replay.add_argument( - "--episode", type=int, default=0, help="Index of the episodes to replay." - ) + parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.") args = parser.parse_args() diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py index 2d844990..4d3cc291 100644 --- a/lerobot/scripts/display_sys_info.py +++ b/lerobot/scripts/display_sys_info.py @@ -59,11 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A" torch_version = torch.__version__ if HAS_TORCH else "N/A" torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" -cuda_version = ( - torch._C._cuda_getCompiledVersion() - if HAS_TORCH and torch.version.cuda is not None - else "N/A" -) +cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" # TODO(aliberts): refactor into an actual command `lerobot env` @@ -81,9 +77,7 @@ def display_sys_info() -> dict: "Using GPU in script?": "", # "Using distributed or parallel set-up in script?": "", } - print( - "\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n" - ) + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") print(format_dict(info)) return info diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index b82ebb20..7f85c3e5 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -152,8 +152,7 @@ def rollout( all_observations.append(deepcopy(observation)) observation = { - key: observation[key].to(device, non_blocking=device.type == "cuda") - for key in observation + key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation } # Infer "task" from attributes of environments. @@ -175,10 +174,7 @@ def rollout( # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't # available of none of the envs finished. if "final_info" in info: - successes = [ - info["is_success"] if info is not None else False - for info in info["final_info"] - ] + successes = [info["is_success"] if info is not None else False for info in info["final_info"]] else: successes = [False] * env.num_envs @@ -192,13 +188,9 @@ def rollout( step += 1 running_success_rate = ( - einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any") - .numpy() - .mean() - ) - progbar.set_postfix( - {"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"} + einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() ) + progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}) progbar.update() # Track the final observation. @@ -216,9 +208,7 @@ def rollout( if return_observations: stacked_observations = {} for key in all_observations[0]: - stacked_observations[key] = torch.stack( - [obs[key] for obs in all_observations], dim=1 - ) + stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) ret["observation"] = stacked_observations if hasattr(policy, "use_original_modules"): @@ -280,9 +270,7 @@ def eval_policy( return n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) if isinstance(env, gym.vector.SyncVectorEnv): - ep_frames.append( - np.stack([env.envs[i].render() for i in range(n_to_render_now)]) - ) # noqa: B023 + ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023 elif isinstance(env, gym.vector.AsyncVectorEnv): # Here we must render all frames and discard any we don't need. ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) @@ -294,9 +282,7 @@ def eval_policy( episode_data: dict | None = None # we dont want progress bar when we use slurm, since it clutters the logs - progbar = trange( - n_batches, desc="Stepping through eval batches", disable=inside_slurm() - ) + progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) for batch_ix in progbar: # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout # step. @@ -326,22 +312,13 @@ def eval_policy( # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. - mask = ( - torch.arange(n_steps) - <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps) - ).int() + mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() # Extend metrics. - batch_sum_rewards = einops.reduce( - (rollout_data["reward"] * mask), "b n -> b", "sum" - ) + batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum") sum_rewards.extend(batch_sum_rewards.tolist()) - batch_max_rewards = einops.reduce( - (rollout_data["reward"] * mask), "b n -> b", "max" - ) + batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max") max_rewards.extend(batch_max_rewards.tolist()) - batch_successes = einops.reduce( - (rollout_data["success"] * mask), "b n -> b", "any" - ) + batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") all_successes.extend(batch_successes.tolist()) if seeds: all_seeds.extend(seeds) @@ -354,27 +331,17 @@ def eval_policy( rollout_data, done_indices, start_episode_index=batch_ix * env.num_envs, - start_data_index=( - 0 - if episode_data is None - else (episode_data["index"][-1].item() + 1) - ), + start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), fps=env.unwrapped.metadata["render_fps"], ) if episode_data is None: episode_data = this_episode_data else: # Some sanity checks to make sure we are correctly compiling the data. - assert ( - episode_data["episode_index"][-1] + 1 - == this_episode_data["episode_index"][0] - ) + assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0] assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] # Concatenate the episode data. - episode_data = { - k: torch.cat([episode_data[k], this_episode_data[k]]) - for k in episode_data - } + episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data} # Maybe render video for visualization. if max_episodes_rendered > 0 and len(ep_frames) > 0: @@ -392,9 +359,7 @@ def eval_policy( target=write_video, args=( str(video_path), - stacked_frames[ - : done_index + 1 - ], # + 1 to capture the last observation + stacked_frames[: done_index + 1], # + 1 to capture the last observation env.unwrapped.metadata["render_fps"], ), ) @@ -403,9 +368,7 @@ def eval_policy( n_episodes_rendered += 1 progbar.set_postfix( - { - "running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%" - } + {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"} ) # Wait till all video rendering threads are done. @@ -473,16 +436,12 @@ def _compile_episode_data( # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. ep_dict = { "action": rollout_data["action"][ep_ix, : num_frames - 1], - "episode_index": torch.tensor( - [start_episode_index + ep_ix] * (num_frames - 1) - ), + "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), "frame_index": torch.arange(0, num_frames - 1, 1), "timestamp": torch.arange(0, num_frames - 1, 1) / fps, "next.done": rollout_data["done"][ep_ix, : num_frames - 1], "next.success": rollout_data["success"][ep_ix, : num_frames - 1], - "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type( - torch.float32 - ), + "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), } # For the last observation frame, all other keys will just be copy padded. @@ -498,9 +457,7 @@ def _compile_episode_data( for key in ep_dicts[0]: data_dict[key] = torch.cat([x[key] for x in ep_dicts]) - data_dict["index"] = torch.arange( - start_data_index, start_data_index + total_frames, 1 - ) + data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) return data_dict @@ -516,14 +473,10 @@ def eval_main(cfg: EvalPipelineConfig): torch.backends.cuda.matmul.allow_tf32 = True set_seed(cfg.seed) - logging.info( - colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}" - ) + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info("Making environment.") - env = make_env( - cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs - ) + env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) logging.info("Making policy.") @@ -535,9 +488,7 @@ def eval_main(cfg: EvalPipelineConfig): with ( torch.no_grad(), - torch.autocast(device_type=device.type) - if cfg.policy.use_amp - else nullcontext(), + torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): info = eval_policy( env, diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py index 8a7062e7..c1e0054b 100644 --- a/lerobot/scripts/eval_on_robot.py +++ b/lerobot/scripts/eval_on_robot.py @@ -74,9 +74,7 @@ def get_classifier(pretrained_path, config_path): cfg = init_hydra_config(config_path) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) - classifier_config.num_cameras = len( - cfg.training.image_keys - ) # TODO automate these paths + classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths model = Classifier(classifier_config) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model = model.to("mps") @@ -161,17 +159,11 @@ def rollout( images = [] for key in image_keys: if display_cameras: - cv2.imshow( - key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) - ) + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.waitKey(1) images.append(observation[key].to("mps")) - reward = ( - reward_classifier.predict_reward(images) - if reward_classifier is not None - else 0.0 - ) + reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0 all_rewards.append(reward) # print("REWARD : ", reward) @@ -235,9 +227,7 @@ def eval_policy( start_eval = time.perf_counter() progbar = trange(n_episodes, desc="Evaluating policy on real robot") - reward_classifier = get_classifier( - reward_classifier_pretrained_path, reward_classifier_config_file - ) + reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file) for _ in progbar: rollout_data = rollout( @@ -313,9 +303,7 @@ def init_keyboard_listener(): print("Right arrow key pressed. Exiting loop...") events["exit_early"] = True elif key == keyboard.Key.left: - print( - "Left arrow key pressed. Exiting loop and rerecord the last episode..." - ) + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") events["rerecord_episode"] = True events["exit_early"] = True elif key == keyboard.Key.space: @@ -380,9 +368,7 @@ if __name__ == "__main__": "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." ), ) - parser.add_argument( - "--revision", help="Optionally provide the Hugging Face Hub revision ID." - ) + parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") parser.add_argument( "--out-dir", help=( diff --git a/lerobot/scripts/find_motors_bus_port.py b/lerobot/scripts/find_motors_bus_port.py index ca56bf48..68f2315d 100644 --- a/lerobot/scripts/find_motors_bus_port.py +++ b/lerobot/scripts/find_motors_bus_port.py @@ -45,13 +45,9 @@ def find_port(): print(f"The port of this MotorsBus is '{port}'") print("Reconnect the USB cable.") elif len(ports_diff) == 0: - raise OSError( - f"Could not detect the port. No difference was found ({ports_diff})." - ) + raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).") else: - raise OSError( - f"Could not detect the port. More than one port was found ({ports_diff})." - ) + raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") if __name__ == "__main__": diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index 8264bf00..e3a47e9b 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -14,18 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from statistics import mean, quantiles +import time from functools import lru_cache -from lerobot.scripts.server.utils import setup_process_handlers +from queue import Empty +from statistics import mean, quantiles # from lerobot.scripts.eval import eval_policy - import grpc import hydra import torch from omegaconf import DictConfig from torch import nn -import time +from torch.multiprocessing import Event, Queue # TODO: Remove the import of maniskill # from lerobot.common.envs.factory import make_maniskill_env @@ -34,34 +34,28 @@ from lerobot.common.policies.factory import make_policy from lerobot.common.policies.sac.modeling_sac import SACPolicy from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.utils.utils import ( TimerManager, get_safe_torch_device, + init_logging, set_global_seed, ) -from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc +from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service from lerobot.scripts.server.buffer import ( Transition, + bytes_to_state_dict, move_state_dict_to_device, move_transition_to_device, python_object_to_bytes, transitions_to_bytes, - bytes_to_state_dict, ) +from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env from lerobot.scripts.server.network_utils import ( receive_bytes_in_chunks, send_bytes_in_chunks, ) -from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env -from lerobot.scripts.server import learner_service -from lerobot.common.robot_devices.utils import busy_wait - -from torch.multiprocessing import Queue, Event -from queue import Empty - -from lerobot.common.utils.utils import init_logging - -from lerobot.scripts.server.utils import get_last_item_from_queue +from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers ACTOR_SHUTDOWN_TIMEOUT = 30 @@ -102,9 +96,7 @@ def receive_policy( logging.info("[ACTOR] Received policy loop stopped") -def transitions_stream( - shutdown_event: Event, transitions_queue: Queue -) -> hilserl_pb2.Empty: +def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty: while not shutdown_event.is_set(): try: message = transitions_queue.get(block=True, timeout=5) @@ -169,9 +161,7 @@ def send_transitions( ) try: - learner_client.SendTransitions( - transitions_stream(shutdown_event, transitions_queue) - ) + learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue)) except grpc.RpcError as e: logging.error(f"[ACTOR] gRPC error: {e}") @@ -211,9 +201,7 @@ def send_interactions( ) try: - learner_client.SendInteractions( - interactions_stream(shutdown_event, interactions_queue) - ) + learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue)) except grpc.RpcError as e: logging.error(f"[ACTOR] gRPC error: {e}") @@ -301,9 +289,7 @@ def act_with_policy( logging.info("make_env online") - online_env = make_robot_env( - robot=robot, reward_classifier=reward_classifier, cfg=cfg - ) + online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg) set_global_seed(cfg.seed) device = get_safe_torch_device(cfg.device, log=True) @@ -355,13 +341,9 @@ def act_with_policy( action = policy.select_action(batch=obs) policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) - log_policy_frequency_issue( - policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step - ) + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) - next_obs, reward, done, truncated, info = online_env.step( - action.squeeze(dim=0).cpu().numpy() - ) + next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy()) else: # TODO (azouitine): Make a custom space for torch tensor action = online_env.action_space.sample() @@ -369,9 +351,7 @@ def act_with_policy( # HACK: We have only one env but we want to batch it, it will be resolved with the torch box action = ( - torch.from_numpy(action[0]) - .to(device, non_blocking=device.type == "cuda") - .unsqueeze(dim=0) + torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0) ) sum_reward_episode += float(reward) @@ -391,9 +371,7 @@ def act_with_policy( # Check for NaN values in observations for key, tensor in obs.items(): if torch.isnan(tensor).any(): - logging.error( - f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}" - ) + logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}") list_transition_to_send_to_learner.append( Transition( @@ -413,13 +391,9 @@ def act_with_policy( # Because we are using a single environment we can index at zero if done or truncated: # TODO: Handle logging for episode information - logging.info( - f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}" - ) + logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") - update_policy_parameters( - policy=policy.actor, parameters_queue=parameters_queue, device=device - ) + update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device) if len(list_transition_to_send_to_learner) > 0: push_transitions_to_transport_queue( @@ -495,9 +469,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: return stats -def log_policy_frequency_issue( - policy_fps: float, cfg: DictConfig, interaction_step: int -): +def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int): if policy_fps < cfg.fps: logging.warning( f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py index 8ca14a03..e10ffbdf 100644 --- a/lerobot/scripts/server/buffer.py +++ b/lerobot/scripts/server/buffer.py @@ -14,16 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import io +import os +import pickle from typing import Any, Callable, Optional, Sequence, TypedDict -import io import torch import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -import os -import pickle class Transition(TypedDict): @@ -45,38 +45,27 @@ class BatchTransition(TypedDict): truncated: torch.Tensor -def move_transition_to_device( - transition: Transition, device: str = "cpu" -) -> Transition: +def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: # Move state tensors to CPU device = torch.device(device) transition["state"] = { - key: val.to(device, non_blocking=device.type == "cuda") - for key, val in transition["state"].items() + key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items() } # Move action to CPU - transition["action"] = transition["action"].to( - device, non_blocking=device.type == "cuda" - ) + transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda") # No need to move reward or done, as they are float and bool # No need to move reward or done, as they are float and bool if isinstance(transition["reward"], torch.Tensor): - transition["reward"] = transition["reward"].to( - device=device, non_blocking=device.type == "cuda" - ) + transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda") if isinstance(transition["done"], torch.Tensor): - transition["done"] = transition["done"].to( - device, non_blocking=device.type == "cuda" - ) + transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda") if isinstance(transition["truncated"], torch.Tensor): - transition["truncated"] = transition["truncated"].to( - device, non_blocking=device.type == "cuda" - ) + transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda") # Move next_state tensors to CPU transition["next_state"] = { @@ -100,10 +89,7 @@ def move_state_dict_to_device(state_dict, device="cpu"): if isinstance(state_dict, torch.Tensor): return state_dict.to(device) elif isinstance(state_dict, dict): - return { - k: move_state_dict_to_device(v, device=device) - for k, v in state_dict.items() - } + return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()} elif isinstance(state_dict, list): return [move_state_dict_to_device(v, device=device) for v in state_dict] elif isinstance(state_dict, tuple): @@ -174,9 +160,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) # Gather pixels - cropped_hwcn = images_hwcn[ - torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, : - ] + cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] # cropped_hwcn => (B, crop_h, crop_w, C) cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) @@ -223,9 +207,7 @@ class ReplayBuffer: self.optimize_memory = optimize_memory # Track episode boundaries for memory optimization - self.episode_ends = torch.zeros( - capacity, dtype=torch.bool, device=storage_device - ) + self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) # If no state_keys provided, default to an empty list self.state_keys = state_keys if state_keys is not None else [] @@ -246,9 +228,7 @@ class ReplayBuffer: key: torch.empty((self.capacity, *shape), device=self.storage_device) for key, shape in state_shapes.items() } - self.actions = torch.empty( - (self.capacity, *action_shape), device=self.storage_device - ) + self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device) self.rewards = torch.empty((self.capacity,), device=self.storage_device) if not self.optimize_memory: @@ -262,12 +242,8 @@ class ReplayBuffer: # Just create a reference to states for consistent API self.next_states = self.states # Just a reference for API consistency - self.dones = torch.empty( - (self.capacity,), dtype=torch.bool, device=self.storage_device - ) - self.truncateds = torch.empty( - (self.capacity,), dtype=torch.bool, device=self.storage_device - ) + self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) self.initialized = True def __len__(self): @@ -294,9 +270,7 @@ class ReplayBuffer: if not self.optimize_memory: # Only store next_states if not optimizing memory - self.next_states[key][self.position].copy_( - next_state[key].squeeze(dim=0) - ) + self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) self.actions[self.position].copy_(action.squeeze(dim=0)) self.rewards[self.position] = reward @@ -309,23 +283,15 @@ class ReplayBuffer: def sample(self, batch_size: int) -> BatchTransition: """Sample a random batch of transitions and collate them into batched tensors.""" if not self.initialized: - raise RuntimeError( - "Cannot sample from an empty buffer. Add transitions first." - ) + raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") batch_size = min(batch_size, self.size) # Random indices for sampling - create on the same device as storage - idx = torch.randint( - low=0, high=self.size, size=(batch_size,), device=self.storage_device - ) + idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device) # Identify image keys that need augmentation - image_keys = ( - [k for k in self.states if k.startswith("observation.image")] - if self.use_drq - else [] - ) + image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] # Create batched state and next_state batch_state = {} @@ -358,13 +324,9 @@ class ReplayBuffer: # Split the augmented images back to their sources for i, key in enumerate(image_keys): # State images are at even indices (0, 2, 4...) - batch_state[key] = augmented_images[ - i * 2 * batch_size : (i * 2 + 1) * batch_size - ] + batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size] # Next state images are at odd indices (1, 3, 5...) - batch_next_state[key] = augmented_images[ - (i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size - ] + batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] # Sample other tensors batch_actions = self.actions[idx].to(self.device) @@ -434,16 +396,12 @@ class ReplayBuffer: ) # Convert dataset to transitions - list_transition = cls._lerobotdataset_to_transitions( - dataset=lerobot_dataset, state_keys=state_keys - ) + list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) # Initialize the buffer with the first transition to set up storage tensors if list_transition: first_transition = list_transition[0] - first_state = { - k: v.to(device) for k, v in first_transition["state"].items() - } + first_state = {k: v.to(device) for k, v in first_transition["state"].items()} first_action = first_transition["action"].to(device) # Apply action mask/delta if needed @@ -541,9 +499,7 @@ class ReplayBuffer: # Convert transitions into episodes and frames episode_index = 0 - lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( - episode_index=episode_index - ) + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index) frame_idx_in_episode = 0 for idx in range(self.size): @@ -557,12 +513,8 @@ class ReplayBuffer: # Fill action, reward, done frame_dict["action"] = self.actions[actual_idx].cpu() - frame_dict["next.reward"] = torch.tensor( - [self.rewards[actual_idx]], dtype=torch.float32 - ).cpu() - frame_dict["next.done"] = torch.tensor( - [self.dones[actual_idx]], dtype=torch.bool - ).cpu() + frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() + frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() # Add to the dataset's buffer lerobot_dataset.add_frame(frame_dict) @@ -619,9 +571,7 @@ class ReplayBuffer: A list of Transition dictionaries with the same length as `dataset`. """ if state_keys is None: - raise ValueError( - "State keys must be provided when converting LeRobotDataset to Transitions." - ) + raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.") transitions = [] num_frames = len(dataset) @@ -632,9 +582,7 @@ class ReplayBuffer: # If not, we need to infer it from episode boundaries if not has_done_key: - print( - "'next.done' key not found in dataset. Inferring from episode boundaries..." - ) + print("'next.done' key not found in dataset. Inferring from episode boundaries...") for i in tqdm(range(num_frames)): current_sample = dataset[i] @@ -886,8 +834,7 @@ if __name__ == "__main__": # We need to be careful because we don't know the original index # So we check if the increment is roughly 0.01 next_state_check = ( - abs(next_state_sig - state_sig - 0.01) < 1e-4 - or abs(next_state_sig - state_sig) < 1e-4 + abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 ) # Count correct relationships @@ -901,17 +848,11 @@ if __name__ == "__main__": total_checks += 3 alignment_accuracy = 100.0 * correct_relationships / total_checks - print( - f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%" - ) + print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%") if alignment_accuracy > 99.0: - print( - "✅ All relationships verified! Buffer maintains correct temporal relationships." - ) + print("✅ All relationships verified! Buffer maintains correct temporal relationships.") else: - print( - "⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues." - ) + print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.") # Print some debug information about failures print("\nDebug information for failed checks:") @@ -973,18 +914,14 @@ if __name__ == "__main__": # Verify consistency before and after conversion original_states = batch["state"]["observation.image"].mean().item() - reconverted_states = ( - reconverted_batch["state"]["observation.image"].mean().item() - ) + reconverted_states = reconverted_batch["state"]["observation.image"].mean().item() print(f"Original buffer state mean: {original_states:.4f}") print(f"Reconverted buffer state mean: {reconverted_states:.4f}") if abs(original_states - reconverted_states) < 1.0: print("Values are reasonably similar - conversion works as expected") else: - print( - "WARNING: Significant difference between original and reconverted values" - ) + print("WARNING: Significant difference between original and reconverted values") print("\nAll previous tests completed!") @@ -1093,15 +1030,11 @@ if __name__ == "__main__": all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device) # Get state tensors - batch_state = { - "value": test_buffer.states["value"][all_indices].to(test_buffer.device) - } + batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)} # Get next_state using memory-optimized approach (simply index+1) next_indices = (all_indices + 1) % test_buffer.capacity - batch_next_state = { - "value": test_buffer.states["value"][next_indices].to(test_buffer.device) - } + batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)} # Get other tensors batch_dones = test_buffer.dones[all_indices].to(test_buffer.device) @@ -1121,9 +1054,7 @@ if __name__ == "__main__": print("- We always use the next state in the buffer (index+1) as next_state") print("- For terminal states, this means using the first state of the next episode") print("- This is a common tradeoff in RL implementations for memory efficiency") - print( - "- Since we track done flags, the algorithm can handle these transitions correctly" - ) + print("- Since we track done flags, the algorithm can handle these transitions correctly") # Test random sampling print("\nVerifying random sampling with simplified memory optimization...") @@ -1137,23 +1068,19 @@ if __name__ == "__main__": # Print a few samples print("Random samples - State, Next State, Done (First 10):") for i in range(10): - print( - f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}" - ) + print(f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}") # Calculate memory savings # Assume optimized_buffer and standard_buffer have already been initialized and filled std_mem = ( sum( - standard_buffer.states[key].nelement() - * standard_buffer.states[key].element_size() + standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size() for key in standard_buffer.states ) * 2 ) opt_mem = sum( - optimized_buffer.states[key].nelement() - * optimized_buffer.states[key].element_size() + optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size() for key in optimized_buffer.states ) diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py index d6c3dd51..cad3419a 100644 --- a/lerobot/scripts/server/crop_dataset_roi.py +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -225,9 +225,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Crop rectangular ROIs from a LeRobot dataset." - ) + parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.") parser.add_argument( "--repo-id", type=str, @@ -249,9 +247,7 @@ if __name__ == "__main__": args = parser.parse_args() local_files_only = args.root is not None - dataset = LeRobotDataset( - repo_id=args.repo_id, root=args.root, local_files_only=local_files_only - ) + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only) images = get_image_from_lerobot_dataset(dataset) images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index b7522d29..d5b217e4 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -1,13 +1,14 @@ -from lerobot.common.robot_devices.robots.factory import make_robot -from lerobot.common.utils.utils import init_hydra_config -from lerobot.common.robot_devices.utils import busy_wait -from lerobot.scripts.server.kinematics import RobotKinematics +import argparse import logging import time -import torch -import numpy as np -import argparse +import numpy as np +import torch + +from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot.common.robot_devices.utils import busy_wait +from lerobot.common.utils.utils import init_hydra_config +from lerobot.scripts.server.kinematics import RobotKinematics logging.basicConfig(level=logging.INFO) @@ -187,9 +188,7 @@ class KeyboardController(InputController): class GamepadController(InputController): """Generate motion deltas from gamepad input.""" - def __init__( - self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1 - ): + def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1): super().__init__(x_step_size, y_step_size, z_step_size) self.deadzone = deadzone self.joystick = None @@ -203,9 +202,7 @@ class GamepadController(InputController): pygame.joystick.init() if pygame.joystick.get_count() == 0: - logging.error( - "No gamepad detected. Please connect a gamepad and try again." - ) + logging.error("No gamepad detected. Please connect a gamepad and try again.") self.running = False return @@ -338,18 +335,12 @@ class GamepadControllerHID(InputController): devices = hid.enumerate() for device in devices: - if ( - device["vendor_id"] == self.vendor_id - and device["product_id"] == self.product_id - ): - logging.info( - f"Found gamepad: {device.get('product_string', 'Unknown')}" - ) + if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id: + logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}") return device logging.error( - f"No gamepad with vendor ID 0x{self.vendor_id:04X} and " - f"product ID 0x{self.product_id:04X} found" + f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found" ) return None @@ -381,9 +372,7 @@ class GamepadControllerHID(InputController): except OSError as e: logging.error(f"Error opening gamepad: {e}") - logging.error( - "You might need to run this with sudo/admin privileges on some systems" - ) + logging.error("You might need to run this with sudo/admin privileges on some systems") self.running = False def stop(self): @@ -421,12 +410,8 @@ class GamepadControllerHID(InputController): # Apply deadzone self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y - self.right_x = ( - 0 if abs(self.right_x) < self.deadzone else self.right_x - ) - self.right_y = ( - 0 if abs(self.right_y) < self.deadzone else self.right_y - ) + self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x + self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y # Parse button states (byte 5 in the Logitech RumblePad 2) buttons = data[5] @@ -493,9 +478,7 @@ def test_inverse_kinematics(robot, fps=10): joint_positions = obs["observation.state"].cpu().numpy() ee_pos = RobotKinematics.fk_gripper_tip(joint_positions) desired_ee_pos = ee_pos - target_joint_state = RobotKinematics.ik( - joint_positions, desired_ee_pos, position_only=True - ) + target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True) robot.send_action(torch.from_numpy(target_joint_state)) logging.info(f"Target Joint State: {target_joint_state}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) @@ -573,17 +556,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): robot.send_action(torch.from_numpy(target_joint_state)) # Logging - logging.info( - f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}" - ) + logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}") logging.info(f"Delta EE: {ee_delta[:3, 3]}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) -def teleoperate_delta_inverse_kinematics( - robot, controller, fps=10, bounds=None, fk_func=None -): +def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None): """ Control a robot using delta end-effector movements from any input controller. @@ -597,9 +576,7 @@ def teleoperate_delta_inverse_kinematics( if fk_func is None: fk_func = RobotKinematics.fk_gripper_tip - logging.info( - f"Testing Delta End-Effector Control with {controller.__class__.__name__}" - ) + logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}") # Initial position capture obs = robot.capture_observation() @@ -631,9 +608,7 @@ def teleoperate_delta_inverse_kinematics( # Apply bounds if provided if bounds is not None: - desired_ee_pos[:3, 3] = np.clip( - desired_ee_pos[:3, 3], bounds["min"], bounds["max"] - ) + desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"]) # Only send commands if there's actual movement if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]): @@ -684,14 +659,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30): if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]): # Step the environment - pass action as a tensor with intervention flag action_tensor = torch.from_numpy(action.astype(np.float32)) - obs, reward, terminated, truncated, info = env.step( - (action_tensor, False) - ) + obs, reward, terminated, truncated, info = env.step((action_tensor, False)) # Log information - logging.info( - f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]" - ) + logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]") logging.info(f"Reward: {reward}") # Reset if episode ended @@ -761,20 +732,14 @@ if __name__ == "__main__": # Determine controller type based on mode prefix controller = None if args.mode.startswith("keyboard"): - controller = KeyboardController( - x_step_size=0.01, y_step_size=0.01, z_step_size=0.05 - ) + controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05) elif args.mode.startswith("gamepad"): - controller = GamepadController( - x_step_size=0.02, y_step_size=0.02, z_step_size=0.05 - ) + controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05) # Handle mode categories if args.mode in ["keyboard", "gamepad"]: # Direct robot control modes - teleoperate_delta_inverse_kinematics( - robot, controller, bounds=bounds, fps=10 - ) + teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10) elif args.mode in ["keyboard_gym", "gamepad_gym"]: # Gym environment control modes diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py index 7834f821..f8891ba7 100644 --- a/lerobot/scripts/server/find_joint_limits.py +++ b/lerobot/scripts/server/find_joint_limits.py @@ -32,9 +32,7 @@ def find_joint_bounds( if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] for key in image_keys: - cv2.imshow( - key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) - ) + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.waitKey(1) if time.perf_counter() - start_episode_t > control_time_s: @@ -69,9 +67,7 @@ def find_ee_bounds( if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] for key in image_keys: - cv2.imshow( - key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) - ) + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) cv2.waitKey(1) if time.perf_counter() - start_episode_t > control_time_s: diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 2dd7bdbb..63a0fbc9 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1,10 +1,10 @@ import argparse -import sys - import logging +import sys import time from threading import Lock from typing import Annotated, Any, Dict, Tuple + import gymnasium as gym import numpy as np import torch @@ -18,7 +18,6 @@ from lerobot.common.robot_devices.control_utils import ( ) from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.utils.utils import init_hydra_config, log_say - from lerobot.scripts.server.kinematics import RobotKinematics logging.basicConfig(level=logging.INFO) @@ -67,9 +66,7 @@ class HILSerlRobotEnv(gym.Env): if not self.robot.is_connected: self.robot.connect() - self.initial_follower_position = robot.follower_arms["main"].read( - "Present_Position" - ) + self.initial_follower_position = robot.follower_arms["main"].read("Present_Position") # Episode tracking. self.current_step = 0 @@ -77,9 +74,7 @@ class HILSerlRobotEnv(gym.Env): self.delta = delta self.use_delta_action_space = use_delta_action_space - self.current_joint_positions = self.robot.follower_arms["main"].read( - "Present_Position" - ) + self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") # Retrieve the size of the joint position interval bound. self.relative_bounds_size = ( @@ -92,9 +87,7 @@ class HILSerlRobotEnv(gym.Env): ) self.robot.config.max_relative_target = ( - self.relative_bounds_size.float() - if self.relative_bounds_size is not None - else None + self.relative_bounds_size.float() if self.relative_bounds_size is not None else None ) # Dynamically configure the observation and action spaces. @@ -119,9 +112,7 @@ class HILSerlRobotEnv(gym.Env): # Define observation spaces for images and other states. image_keys = [key for key in example_obs if "image" in key] observation_spaces = { - key: gym.spaces.Box( - low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8 - ) + key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8) for key in image_keys } observation_spaces["observation.state"] = gym.spaces.Box( @@ -172,9 +163,7 @@ class HILSerlRobotEnv(gym.Env): ), ) - def reset( - self, seed=None, options=None - ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: """ Reset the environment to its initial state. This method resets the step counter and clears any episodic data. @@ -231,35 +220,25 @@ class HILSerlRobotEnv(gym.Env): """ policy_action, intervention_bool = action teleop_action = None - self.current_joint_positions = self.robot.follower_arms["main"].read( - "Present_Position" - ) + self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") if isinstance(policy_action, torch.Tensor): policy_action = policy_action.cpu().numpy() - policy_action = np.clip( - policy_action, self.action_space[0].low, self.action_space[0].high - ) + policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high) if not intervention_bool: if self.use_delta_action_space: - target_joint_positions = ( - self.current_joint_positions + self.delta * policy_action - ) + target_joint_positions = self.current_joint_positions + self.delta * policy_action else: target_joint_positions = policy_action self.robot.send_action(torch.from_numpy(target_joint_positions)) observation = self.robot.capture_observation() else: observation, teleop_action = self.robot.teleop_step(record_data=True) - teleop_action = teleop_action[ - "action" - ] # Convert tensor to appropriate format + teleop_action = teleop_action["action"] # Convert tensor to appropriate format # When applying the delta action space, convert teleop absolute values to relative differences. if self.use_delta_action_space: - teleop_action = ( - teleop_action - self.current_joint_positions - ) / self.delta + teleop_action = (teleop_action - self.current_joint_positions) / self.delta if self.relative_bounds_size is not None and ( torch.any(teleop_action < -self.relative_bounds_size) and torch.any(teleop_action > self.relative_bounds_size) @@ -333,12 +312,8 @@ class AddJointVelocityToObservation(gym.ObservationWrapper): self.last_joint_positions = np.zeros(old_shape) - new_low = np.concatenate( - [old_low, np.ones_like(old_low) * -joint_velocity_limits] - ) - new_high = np.concatenate( - [old_high, np.ones_like(old_high) * joint_velocity_limits] - ) + new_low = np.concatenate([old_low, np.ones_like(old_low) * -joint_velocity_limits]) + new_high = np.concatenate([old_high, np.ones_like(old_high) * joint_velocity_limits]) new_shape = (old_shape[0] * 2,) @@ -352,9 +327,7 @@ class AddJointVelocityToObservation(gym.ObservationWrapper): self.dt = 1.0 / fps def observation(self, observation): - joint_velocities = ( - observation["observation.state"] - self.last_joint_positions - ) / self.dt + joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt self.last_joint_positions = observation["observation.state"].clone() observation["observation.state"] = torch.cat( [observation["observation.state"], joint_velocities], dim=-1 @@ -439,9 +412,7 @@ class JointMaskingActionSpace(gym.Wrapper): raise ValueError("Mask length must match action space dimensions") low = env.action_space.low[self.active_dims] high = env.action_space.high[self.active_dims] - self.action_space = gym.spaces.Box( - low=low, high=high, dtype=env.action_space.dtype - ) + self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype) if isinstance(env.action_space, gym.spaces.Tuple): if len(mask) != env.action_space[0].shape[0]: @@ -449,12 +420,8 @@ class JointMaskingActionSpace(gym.Wrapper): low = env.action_space[0].low[self.active_dims] high = env.action_space[0].high[self.active_dims] - action_space_masked = gym.spaces.Box( - low=low, high=high, dtype=env.action_space[0].dtype - ) - self.action_space = gym.spaces.Tuple( - (action_space_masked, env.action_space[1]) - ) + action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype) + self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1])) # Create new action space with masked dimensions def action(self, action): @@ -473,18 +440,14 @@ class JointMaskingActionSpace(gym.Wrapper): # Extract the masked component from the tuple. masked_action = action[0] if isinstance(action, tuple) else action # Create a full action for the Box element. - full_box_action = np.zeros( - self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype - ) + full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype) full_box_action[self.active_dims] = masked_action # Return a tuple with the reconstructed Box action and the unchanged remainder. return (full_box_action, action[1]) else: # For Box action spaces. masked_action = action if not isinstance(action, tuple) else action[0] - full_action = np.zeros( - self.env.action_space.shape, dtype=self.env.action_space.dtype - ) + full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype) full_action[self.active_dims] = masked_action return full_action @@ -493,13 +456,9 @@ class JointMaskingActionSpace(gym.Wrapper): obs, reward, terminated, truncated, info = self.env.step(action) if "action_intervention" in info and info["action_intervention"] is not None: if info["action_intervention"].dim() == 1: - info["action_intervention"] = info["action_intervention"][ - self.active_dims - ] + info["action_intervention"] = info["action_intervention"][self.active_dims] else: - info["action_intervention"] = info["action_intervention"][ - :, self.active_dims - ] + info["action_intervention"] = info["action_intervention"][:, self.active_dims] return obs, reward, terminated, truncated, info @@ -555,9 +514,7 @@ class ImageCropResizeWrapper(gym.Wrapper): for key in crop_params_dict: top, left, height, width = crop_params_dict[key] new_shape = (top + height, left + width) - self.observation_space[key] = gym.spaces.Box( - low=0, high=255, shape=new_shape - ) + self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) self.resize_size = resize_size if self.resize_size is None: @@ -583,9 +540,7 @@ class ImageCropResizeWrapper(gym.Wrapper): ) # Check for NaNs before processing if torch.isnan(obs[k]).any(): - logging.error( - f"NaN values detected in observation {k} before crop and resize" - ) + logging.error(f"NaN values detected in observation {k} before crop and resize") if device == torch.device("mps:0"): obs[k] = obs[k].cpu() @@ -595,9 +550,7 @@ class ImageCropResizeWrapper(gym.Wrapper): # Check for NaNs after processing if torch.isnan(obs[k]).any(): - logging.error( - f"NaN values detected in observation {k} after crop and resize" - ) + logging.error(f"NaN values detected in observation {k} after crop and resize") obs[k] = obs[k].to(device) @@ -627,14 +580,10 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper): observation = preprocess_observation(observation) observation = { - key: observation[key].to( - self.device, non_blocking=self.device.type == "cuda" - ) + key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") for key in observation } - observation = { - k: torch.tensor(v, device=self.device) for k, v in observation.items() - } + observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()} return observation @@ -686,26 +635,16 @@ class KeyboardInterfaceWrapper(gym.Wrapper): play_sounds=True, ) return - if ( - self.events["pause_policy"] - and not self.events["human_intervention_step"] - ): + if self.events["pause_policy"] and not self.events["human_intervention_step"]: self.events["human_intervention_step"] = True print("Space key pressed. Human intervention starting.") - log_say( - "Starting human intervention.", play_sounds=True - ) + log_say("Starting human intervention.", play_sounds=True) return - if ( - self.events["pause_policy"] - and self.events["human_intervention_step"] - ): + if self.events["pause_policy"] and self.events["human_intervention_step"]: self.events["pause_policy"] = False self.events["human_intervention_step"] = False print("Space key pressed for a third time.") - log_say( - "Continuing with policy actions.", play_sounds=True - ) + log_say("Continuing with policy actions.", play_sounds=True) return except Exception as e: print(f"Error handling key press: {e}") @@ -713,9 +652,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper): self.listener = keyboard.Listener(on_press=on_press) self.listener.start() except ImportError: - logging.warning( - "Could not import pynput. Keyboard interface will not be available." - ) + logging.warning("Could not import pynput. Keyboard interface will not be available.") self.listener = None def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: @@ -742,9 +679,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper): time.sleep(0.1) # Check more frequently if desired # Execute the step in the underlying environment - obs, reward, terminated, truncated, info = self.env.step( - (policy_action, is_intervention) - ) + obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention)) # Override reward and termination if episode success event triggered with self.event_lock: @@ -807,9 +742,7 @@ class BatchCompitableWrapper(gym.ObservationWrapper): def __init__(self, env): super().__init__(env) - def observation( - self, observation: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: for key in observation: if "image" in key and observation[key].dim() == 3: observation[key] = observation[key].unsqueeze(0) @@ -844,9 +777,7 @@ class EEActionWrapper(gym.ActionWrapper): dtype=np.float32, ) if isinstance(self.action_space, gym.spaces.Tuple): - self.action_space = gym.spaces.Tuple( - (ee_action_space, self.action_space[1]) - ) + self.action_space = gym.spaces.Tuple((ee_action_space, self.action_space[1])) else: self.action_space = ee_action_space @@ -858,9 +789,7 @@ class EEActionWrapper(gym.ActionWrapper): if isinstance(action, tuple): action, _ = action - current_joint_pos = self.unwrapped.robot.follower_arms["main"].read( - "Present_Position" - ) + current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position") current_ee_pos = self.fk_function(current_joint_pos) if isinstance(action, torch.Tensor): action = action.cpu().numpy() @@ -898,9 +827,7 @@ class EEObservationWrapper(gym.ObservationWrapper): self.fk_function = self.kinematics.fk_gripper_tip def observation(self, observation): - current_joint_pos = self.unwrapped.robot.follower_arms["main"].read( - "Present_Position" - ) + current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position") current_ee_pos = self.fk_function(current_joint_pos) observation["observation.state"] = torch.cat( [ @@ -944,8 +871,8 @@ class GamepadControlWrapper(gym.Wrapper): """ super().__init__(env) from lerobot.scripts.server.end_effector_control_utils import ( - GamepadControllerHID, GamepadController, + GamepadControllerHID, ) # use HidApi for macos @@ -1027,9 +954,7 @@ class GamepadControlWrapper(gym.Wrapper): # Update episode ending state if requested if terminate_episode: - logging.info( - f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}" - ) + logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}") # Only override the action if gamepad is active if is_intervention: @@ -1054,9 +979,7 @@ class GamepadControlWrapper(gym.Wrapper): logging.info("Episode ended successfully with reward 1.0") info["is_intervention"] = is_intervention - action_intervention = ( - final_action[0] if isinstance(final_action, Tuple) else final_action - ) + action_intervention = final_action[0] if isinstance(final_action, Tuple) else final_action if isinstance(action_intervention, np.ndarray): action_intervention = torch.from_numpy(action_intervention) info["action_intervention"] = action_intervention @@ -1087,9 +1010,7 @@ class GamepadControlWrapper(gym.Wrapper): class ActionScaleWrapper(gym.ActionWrapper): def __init__(self, env, ee_action_space_params=None): super().__init__(env) - assert ee_action_space_params is not None, ( - "TODO: method implemented for ee action space only so far" - ) + assert ee_action_space_params is not None, "TODO: method implemented for ee action space only so far" self.scale_vector = np.array( [ [ @@ -1148,9 +1069,7 @@ def make_robot_env( if cfg.env.wrapper.add_joint_velocity_to_observation: env = AddJointVelocityToObservation(env=env, fps=cfg.fps) if cfg.env.wrapper.add_ee_pose_to_observation: - env = EEObservationWrapper( - env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds - ) + env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds) env = ConvertToLeRobotObservation(env=env, device=cfg.env.device) @@ -1163,13 +1082,9 @@ def make_robot_env( # Add reward computation and control wrappers # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) - env = TimeLimitWrapper( - env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps - ) + env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps) if cfg.env.wrapper.ee_action_space_params is not None: - env = EEActionWrapper( - env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params - ) + env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params) if ( cfg.env.wrapper.ee_action_space_params is not None and cfg.env.wrapper.ee_action_space_params.use_gamepad @@ -1193,9 +1108,7 @@ def make_robot_env( cfg.env.wrapper.ee_action_space_params is None and cfg.env.wrapper.joint_masking_action_space is not None ): - env = JointMaskingActionSpace( - env=env, mask=cfg.env.wrapper.joint_masking_action_space - ) + env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space) env = BatchCompitableWrapper(env=env) return env @@ -1216,9 +1129,7 @@ def get_classifier(pretrained_path, config_path, device="mps"): cfg = init_hydra_config(config_path) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) - classifier_config.num_cameras = len( - cfg.training.image_keys - ) # TODO automate these paths + classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths model = Classifier(classifier_config) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model = model.to(device) @@ -1317,9 +1228,7 @@ def record_dataset( # For teleop, get action from intervention if policy is None: - action = { - "action": info["action_intervention"].cpu().squeeze(0).float() - } + action = {"action": info["action_intervention"].cpu().squeeze(0).float()} # Process observation for dataset obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} @@ -1357,9 +1266,7 @@ def replay_episode(env, repo_id, root=None, episode=0): from lerobot.common.datasets.lerobot_dataset import LeRobotDataset local_files_only = root is not None - dataset = LeRobotDataset( - repo_id, root=root, episodes=[episode], local_files_only=local_files_only - ) + dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only) env.reset() actions = dataset.hf_dataset.select_columns("action") @@ -1414,9 +1321,7 @@ if __name__ == "__main__": default=None, help="Path to a yaml config file that is necessary to build the reward classifier model.", ) - parser.add_argument( - "--env-path", type=str, default=None, help="Path to the env yaml file" - ) + parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file") parser.add_argument( "--env-overrides", type=str, @@ -1441,12 +1346,8 @@ if __name__ == "__main__": default=None, help="Repo ID of the episode to replay", ) - parser.add_argument( - "--dataset-root", type=str, default=None, help="Root of the dataset to replay" - ) - parser.add_argument( - "--replay-episode", type=int, default=0, help="Episode to replay" - ) + parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay") + parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay") parser.add_argument( "--record-repo-id", type=str, @@ -1534,9 +1435,7 @@ if __name__ == "__main__": smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action # Execute the step: wrap the NumPy action in a torch tensor. - obs, reward, terminated, truncated, info = env.step( - (torch.from_numpy(smoothed_action), False) - ) + obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False)) if terminated or truncated: sucesses.append(reward) env.reset() diff --git a/lerobot/scripts/server/kinematics.py b/lerobot/scripts/server/kinematics.py index 6622fe76..fb16548a 100644 --- a/lerobot/scripts/server/kinematics.py +++ b/lerobot/scripts/server/kinematics.py @@ -23,11 +23,7 @@ def screw_axis_to_transform(S, theta): elif np.linalg.norm(S_w) == 1: # Rotation and translation w_hat = skew_symmetric(S_w) R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat - t = ( - np.eye(3) * theta - + (1 - np.cos(theta)) * w_hat - + (theta - np.sin(theta)) * w_hat @ w_hat - ) @ S_v + t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v T = np.eye(4) T[:3, :3] = R T[:3, 3] = t @@ -189,9 +185,7 @@ class RobotKinematics: # Wrist # Screw axis of wrist frame wrt base frame - self.S_BR = np.array( - [0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]] - ) + self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]) # 0-position origin to centroid transform self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002) @@ -284,12 +278,7 @@ class RobotKinematics: def fk_shoulder(self, robot_pos_deg): """Forward kinematics for the shoulder frame.""" robot_pos_rad = robot_pos_deg / 180 * np.pi - return ( - self.X_WoBo - @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) - @ self.X_SoSc - @ self.X_BS - ) + return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS def fk_humerus(self, robot_pos_deg): """Forward kinematics for the humerus frame.""" @@ -403,15 +392,12 @@ class RobotKinematics: delta *= 0 delta[el_ix] = eps / 2 Sdot = ( - fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - - fk_func(robot_pos_deg[:-1] - delta)[:3, 3] + fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3] ) / eps jac[:, el_ix] = Sdot return jac - def ik( - self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None - ): + def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None): """Inverse kinematics using gradient descent. Args: @@ -457,9 +443,7 @@ if __name__ == "__main__": # Test 1: Forward kinematics consistency print("Test 1: Forward kinematics consistency") - test_angles = np.array( - [30, 45, -30, 20, 10, 0] - ) # Example joint angles in degrees + test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees # Calculate FK for different joints shoulder_pose = robot.fk_shoulder(test_angles) @@ -480,13 +464,9 @@ if __name__ == "__main__": ] # Check if distances generally increase along the chain - is_consistent = all( - distances[i] <= distances[i + 1] for i in range(len(distances) - 1) - ) + is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1)) print(f" Pose distances from origin: {[round(d, 3) for d in distances]}") - print( - f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}" - ) + print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}") # Test 2: Jacobian computation print("Test 2: Jacobian computation") @@ -498,9 +478,7 @@ if __name__ == "__main__": pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5) print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}") - print( - f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}" - ) + print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}") # Test 3: Inverse kinematics print("Test 3: Inverse kinematics (position only)") diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index 82765a28..7d7db5cd 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -17,15 +17,8 @@ import logging import shutil import time -from pprint import pformat from concurrent.futures import ThreadPoolExecutor - -# from torch.multiprocessing import Event, Queue, Process -# from threading import Event, Thread -# from torch.multiprocessing import Queue, Event -from torch.multiprocessing import Queue - -from lerobot.scripts.server.utils import setup_process_handlers +from pprint import pformat import grpc @@ -37,6 +30,11 @@ from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch import nn + +# from torch.multiprocessing import Event, Queue, Process +# from threading import Event, Thread +# from torch.multiprocessing import Queue, Event +from torch.multiprocessing import Queue from torch.optim.optimizer import Optimizer from lerobot.common.datasets.factory import make_dataset @@ -55,18 +53,17 @@ from lerobot.common.utils.utils import ( set_global_random_state, set_global_seed, ) - +from lerobot.scripts.server import learner_service from lerobot.scripts.server.buffer import ( ReplayBuffer, - concatenate_batch_transitions, - move_transition_to_device, - move_state_dict_to_device, - bytes_to_transitions, - state_to_bytes, bytes_to_python_object, + bytes_to_transitions, + concatenate_batch_transitions, + move_state_dict_to_device, + move_transition_to_device, + state_to_bytes, ) - -from lerobot.scripts.server import learner_service +from lerobot.scripts.server.utils import setup_process_handlers def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig: @@ -81,13 +78,9 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig: # if resume == True checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir) if not checkpoint_dir.exists(): - raise RuntimeError( - f"No model checkpoint found in {checkpoint_dir} for resume=True" - ) + raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True") - checkpoint_cfg_path = str( - Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml" - ) + checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") logging.info( colored( "Resume=True detected, resuming previous run", @@ -136,9 +129,7 @@ def load_training_state( def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: - num_learnable_params = sum( - p.numel() for p in policy.parameters() if p.requires_grad - ) + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) log_output_dir(out_dir) @@ -210,22 +201,15 @@ def initialize_offline_replay_buffer( def get_observation_features( policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if ( - policy.config.vision_encoder_name is None - or not policy.config.freeze_vision_encoder - ): + if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: return None, None with torch.no_grad(): observation_features = ( - policy.actor.encoder(observations) - if policy.actor.encoder is not None - else None + policy.actor.encoder(observations) if policy.actor.encoder is not None else None ) next_observation_features = ( - policy.actor.encoder(next_observations) - if policy.actor.encoder is not None - else None + policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None ) return observation_features, next_observation_features @@ -452,9 +436,7 @@ def add_actor_information_and_train( # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # Hack: But if we do online traning, we do not need dataset_stats dataset_stats=None, - pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) - if cfg.resume - else None, + pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, ) # Update the policy config with the grad_clip_norm value from training config if it exists @@ -469,9 +451,7 @@ def add_actor_information_and_train( last_time_policy_pushed = time.time() optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy) - resume_optimization_step, resume_interaction_step = load_training_state( - cfg, logger, optimizers - ) + resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers) log_training_info(cfg, out_dir, policy) @@ -483,9 +463,7 @@ def add_actor_information_and_train( active_action_dims = None if cfg.env.wrapper.joint_masking_action_space is not None: active_action_dims = [ - i - for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) - if mask + i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask ] offline_replay_buffer = initialize_offline_replay_buffer( cfg=cfg, @@ -502,12 +480,8 @@ def add_actor_information_and_train( time.time() logging.info("Starting learner thread") interaction_message, transition = None, None - optimization_step = ( - resume_optimization_step if resume_optimization_step is not None else 0 - ) - interaction_step_shift = ( - resume_interaction_step if resume_interaction_step is not None else 0 - ) + optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 + interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 # Extract variables from cfg online_step_before_learning = cfg.training.online_step_before_learning @@ -519,9 +493,7 @@ def add_actor_information_and_train( device = cfg.device storage_device = cfg.training.storage_device policy_update_freq = cfg.training.policy_update_freq - policy_parameters_push_frequency = ( - cfg.actor_learner_config.policy_parameters_push_frequency - ) + policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency save_checkpoint = cfg.training.save_checkpoint online_steps = cfg.training.online_steps @@ -544,9 +516,9 @@ def add_actor_information_and_train( continue replay_buffer.add(**transition) - if cfg.dataset_repo_id is not None and transition.get( - "complementary_info", {} - ).get("is_intervention"): + if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get( + "is_intervention" + ): offline_replay_buffer.add(**transition) logging.debug("[LEARNER] Received transitions") @@ -556,9 +528,7 @@ def add_actor_information_and_train( interaction_message = bytes_to_python_object(interaction_message) # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging interaction_message["Interaction step"] += interaction_step_shift - logger.log_dict( - interaction_message, mode="train", custom_step_key="Interaction step" - ) + logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step") logging.debug("[LEARNER] Received interactions") @@ -579,9 +549,7 @@ def add_actor_information_and_train( observations = batch["state"] next_observations = batch["next_state"] done = batch["done"] - check_nan_in_transition( - observations=observations, actions=actions, next_state=next_observations - ) + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( policy, observations, next_observations @@ -619,9 +587,7 @@ def add_actor_information_and_train( next_observations = batch["next_state"] done = batch["done"] - check_nan_in_transition( - observations=observations, actions=actions, next_state=next_observations - ) + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) observation_features, next_observation_features = get_observation_features( policy, observations, next_observations @@ -697,23 +663,15 @@ def add_actor_information_and_train( if optimization_step % log_freq == 0: training_infos["replay_buffer_size"] = len(replay_buffer) if offline_replay_buffer is not None: - training_infos["offline_replay_buffer_size"] = len( - offline_replay_buffer - ) + training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) training_infos["Optimization step"] = optimization_step - logger.log_dict( - d=training_infos, mode="train", custom_step_key="Optimization step" - ) + logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step") # logging.info(f"Training infos: {training_infos}") time_for_one_optimization_step = time.time() - time_for_one_optimization_step - frequency_for_one_optimization_step = 1 / ( - time_for_one_optimization_step + 1e-9 - ) + frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) - logging.info( - f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}" - ) + logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") logger.log_dict( { @@ -728,16 +686,12 @@ def add_actor_information_and_train( if optimization_step % log_freq == 0: logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") - if save_checkpoint and ( - optimization_step % save_freq == 0 or optimization_step == online_steps - ): + if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps): logging.info(f"Checkpoint policy after step {optimization_step}") _num_digits = max(6, len(str(online_steps))) step_identifier = f"{optimization_step:0{_num_digits}d}" interaction_step = ( - interaction_message["Interaction step"] - if interaction_message is not None - else 0 + interaction_message["Interaction step"] if interaction_message is not None else 0 ) logger.save_checkpoint( optimization_step, @@ -755,9 +709,7 @@ def add_actor_information_and_train( shutil.rmtree( dataset_dir, ) - replay_buffer.to_lerobot_dataset( - dataset_repo_id, fps=fps, root=logger.log_dir / "dataset" - ) + replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset") if offline_replay_buffer is not None: dataset_dir = logger.log_dir / "dataset_offline" @@ -809,9 +761,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module): optimizer_critic = torch.optim.Adam( params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr ) - optimizer_temperature = torch.optim.Adam( - params=[policy.log_alpha], lr=policy.config.critic_lr - ) + optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr) lr_scheduler = None optimizers = { "actor": optimizer_actor, diff --git a/lerobot/scripts/server/learner_service.py b/lerobot/scripts/server/learner_service.py index b1f91cdc..2aabf275 100644 --- a/lerobot/scripts/server/learner_service.py +++ b/lerobot/scripts/server/learner_service.py @@ -1,10 +1,10 @@ -import hilserl_pb2 # type: ignore -import hilserl_pb2_grpc # type: ignore import logging from multiprocessing import Event, Queue -from lerobot.scripts.server.network_utils import receive_bytes_in_chunks -from lerobot.scripts.server.network_utils import send_bytes_in_chunks +import hilserl_pb2 # type: ignore +import hilserl_pb2_grpc # type: ignore + +from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB MAX_WORKERS = 3 # Stream parameters, send transitions and interactions @@ -64,9 +64,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): def SendInteractions(self, request_iterator, _context): # TODO: authorize the request - logging.info( - "[LEARNER] Received request to receive interactions from the Actor" - ) + logging.info("[LEARNER] Received request to receive interactions from the Actor") receive_bytes_in_chunks( request_iterator, diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py index 495042de..9db7aa40 100644 --- a/lerobot/scripts/server/maniskill_manipulator.py +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -1,12 +1,12 @@ -import einops -import numpy as np -import gymnasium as gym -import torch - -from omegaconf import DictConfig from typing import Any -from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv + +import einops +import gymnasium as gym +import numpy as np +import torch from mani_skill.utils.wrappers.record import RecordEpisode +from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv +from omegaconf import DictConfig def preprocess_maniskill_observation( @@ -63,9 +63,7 @@ class ManiSkillCompat(gym.Wrapper): new_action_space_shape = env.action_space.shape[-1] new_low = np.squeeze(env.action_space.low, axis=0) new_high = np.squeeze(env.action_space.high, axis=0) - self.action_space = gym.spaces.Box( - low=new_low, high=new_high, shape=(new_action_space_shape,) - ) + self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,)) def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None @@ -84,9 +82,7 @@ class ManiSkillCompat(gym.Wrapper): class ManiSkillActionWrapper(gym.ActionWrapper): def __init__(self, env): super().__init__(env) - self.action_space = gym.spaces.Tuple( - spaces=(env.action_space, gym.spaces.Discrete(2)) - ) + self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2))) def action(self, action): action, telop = action @@ -100,9 +96,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper): action_space_agent: gym.spaces.Box = env.action_space[0] action_space_agent.low = action_space_agent.low * multiply_factor action_space_agent.high = action_space_agent.high * multiply_factor - self.action_space = gym.spaces.Tuple( - spaces=(action_space_agent, gym.spaces.Discrete(2)) - ) + self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2))) def step(self, action): if isinstance(action, tuple): @@ -153,9 +147,7 @@ def make_maniskill( ) env = ManiSkillObservationWrapper(env, device=cfg.env.device) env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False) - env._max_episode_steps = env.max_episode_steps = ( - 50 # gym_utils.find_max_episode_steps_value(env) - ) + env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env) env.unwrapped.metadata["render_fps"] = 20 env = ManiSkillCompat(env) env = ManiSkillActionWrapper(env) @@ -166,12 +158,11 @@ def make_maniskill( if __name__ == "__main__": import argparse + import hydra parser = argparse.ArgumentParser() - parser.add_argument( - "--config", type=str, default="lerobot/configs/env/maniskill_example.yaml" - ) + parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml") args = parser.parse_args() # Initialize config diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/scripts/server/network_utils.py index 03ca06ca..78b9e5db 100644 --- a/lerobot/scripts/server/network_utils.py +++ b/lerobot/scripts/server/network_utils.py @@ -15,12 +15,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.scripts.server import hilserl_pb2 -import logging import io -from multiprocessing import Queue, Event +import logging +from multiprocessing import Event, Queue from typing import Any +from lerobot.scripts.server import hilserl_pb2 + CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB @@ -31,9 +32,7 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int: return result -def send_bytes_in_chunks( - buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True -): +def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True): buffer = io.BytesIO(buffer) size_in_bytes = bytes_buffer_size(buffer) @@ -56,16 +55,12 @@ def send_bytes_in_chunks( yield message_class(transfer_state=transfer_state, data=chunk) sent_bytes += size_to_read - logging_method( - f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}" - ) + logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") -def receive_bytes_in_chunks( - iterator, queue: Queue, shutdown_event: Event, log_prefix: str = "" -): +def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): bytes_buffer = io.BytesIO() step = 0 @@ -89,9 +84,7 @@ def receive_bytes_in_chunks( logging.debug(f"{log_prefix} Received data at step {step}") elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END: bytes_buffer.write(item.data) - logging.debug( - f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}" - ) + logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") queue.put(bytes_buffer.getvalue()) diff --git a/lerobot/scripts/server/utils.py b/lerobot/scripts/server/utils.py index 699717e4..2ce4e57f 100644 --- a/lerobot/scripts/server/utils.py +++ b/lerobot/scripts/server/utils.py @@ -18,9 +18,10 @@ import logging import signal import sys -from torch.multiprocessing import Queue from queue import Empty +from torch.multiprocessing import Queue + shutdown_event_counter = 0 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 878b0111..0bdf61e7 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -223,18 +223,12 @@ def train(cfg: TrainPipelineConfig): step = 0 # number of policy updates (forward + backward + optim) if cfg.resume: - step, optimizer, lr_scheduler = load_training_state( - cfg.checkpoint_path, optimizer, lr_scheduler - ) + step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) - num_learnable_params = sum( - p.numel() for p in policy.parameters() if p.requires_grad - ) + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - logging.info( - colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}" - ) + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") if cfg.env is not None: logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") @@ -335,9 +329,7 @@ def train(cfg: TrainPipelineConfig): logging.info(f"Eval policy at step {step}") with ( torch.no_grad(), - torch.autocast(device_type=device.type) - if cfg.policy.use_amp - else nullcontext(), + torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): eval_info = eval_policy( eval_env, diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 5d877bba..fc9a1d07 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -52,19 +52,13 @@ def get_model(cfg, logger): # noqa I001 classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) model = Classifier(classifier_config) if cfg.resume: - model.load_state_dict( - Classifier.from_pretrained( - str(logger.last_pretrained_model_dir) - ).state_dict() - ) + model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict()) return model def create_balanced_sampler(dataset, cfg): # Get underlying dataset if using Subset - original_dataset = ( - dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset - ) + original_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset # Get indices if using Subset (for slicing) indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None @@ -83,9 +77,7 @@ def create_balanced_sampler(dataset, cfg): class_weights = 1.0 / counts.float() sample_weights = class_weights[labels] - return WeightedRandomSampler( - weights=sample_weights, num_samples=len(sample_weights), replacement=True - ) + return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) def support_amp(device: torch.device, cfg: DictConfig) -> bool: @@ -94,9 +86,7 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool: return cfg.training.use_amp and device.type in ("cuda", "cpu") -def train_epoch( - model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg -): +def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg): # Single epoch training loop with AMP support and progress tracking model.train() correct = 0 @@ -110,11 +100,7 @@ def train_epoch( labels = batch[cfg.training.label_key].float().to(device) # Forward pass with optional AMP - with ( - torch.autocast(device_type=device.type) - if support_amp(device, cfg) - else nullcontext() - ): + with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(): outputs = model(images) loss = criterion(outputs.logits, labels) @@ -159,9 +145,7 @@ def validate(model, val_loader, criterion, device, logger, cfg): with ( torch.no_grad(), - torch.autocast(device_type=device.type) - if support_amp(device, cfg) - else nullcontext(), + torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(), ): for batch in tqdm(val_loader, desc="Validation"): images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] @@ -174,9 +158,7 @@ def validate(model, val_loader, criterion, device, logger, cfg): ): outputs = model(images) inference_times.append( - next( - x for x in prof.key_averages() if x.key == "model_inference" - ).cpu_time + next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time ) else: outputs = model(images) @@ -194,24 +176,16 @@ def validate(model, val_loader, criterion, device, logger, cfg): # Log sample predictions for visualization if len(samples) < cfg.eval.num_samples_to_log: - for i in range( - min(cfg.eval.num_samples_to_log - len(samples), len(images)) - ): + for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))): if model.config.num_classes == 2: confidence = round(outputs.probabilities[i].item(), 3) else: - confidence = [ - round(prob, 3) for prob in outputs.probabilities[i].tolist() - ] + confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()] samples.append( { **{ - f"image_{img_key}": wandb.Image( - images[img_idx][i].cpu() - ) - for img_idx, img_key in enumerate( - cfg.training.image_keys - ) + f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) + for img_idx, img_key in enumerate(cfg.training.image_keys) }, "true_label": labels[i].item(), "predicted": predictions[i].item(), @@ -286,9 +260,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step): _ = model(x) inference_times.append( - next( - x for x in prof.key_averages() if x.key == "model_inference" - ).cpu_time + next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time ) inference_times = np.array(inference_times) @@ -314,9 +286,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step): return avg, median, std -def train( - cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None -) -> None: +def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None) -> None: if out_dir is None: raise NotImplementedError() if job_name is None: @@ -372,9 +342,7 @@ def train( "You have set resume=True, but there is no model checkpoint in " f"{Logger.get_last_checkpoint_dir(out_dir)}" ) - checkpoint_cfg_path = str( - Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml" - ) + checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") logging.info( colored( "You have set resume=True, indicating that you wish to resume a run", @@ -387,9 +355,7 @@ def train( # Check for differences between the checkpoint configuration and provided configuration. # Hack to resolve the delta_timestamps ahead of time in order to properly diff. resolve_delta_timestamps(cfg) - diff = DeepDiff( - OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg) - ) + diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) # Ignore the `resume` and parameters. if "values_changed" in diff and "root['resume']" in diff["values_changed"]: del diff["values_changed"]["root['resume']"] @@ -408,11 +374,7 @@ def train( optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate) # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class - criterion = ( - nn.BCEWithLogitsLoss() - if model.config.num_classes == 2 - else nn.CrossEntropyLoss() - ) + criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss() grad_scaler = GradScaler(enabled=cfg.training.use_amp) # Log model parameters diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py index cfd05f62..f2ce67f0 100644 --- a/lerobot/scripts/train_sac.py +++ b/lerobot/scripts/train_sac.py @@ -52,9 +52,7 @@ def make_optimizers_and_scheduler(cfg, policy): params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr ) # We wrap policy log temperature in list because this is a torch tensor and not a nn.Module - optimizer_temperature = torch.optim.Adam( - params=[policy.log_alpha], lr=policy.config.critic_lr - ) + optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr) lr_scheduler = None optimizers = { "actor": optimizer_actor, @@ -106,9 +104,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) # Gather pixels - cropped_hwcn = images_hwcn[ - torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, : - ] + cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] # cropped_hwcn => (B, crop_h, crop_w, C) cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) @@ -198,12 +194,8 @@ class ReplayBuffer: """ # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from # a replay buffer than from a lerobot dataset. - replay_buffer = cls( - capacity=len(lerobot_dataset), device=device, state_keys=state_keys - ) - list_transition = cls._lerobotdataset_to_transitions( - dataset=lerobot_dataset, state_keys=state_keys - ) + replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys) + list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) # Fill the replay buffer with the lerobot dataset transitions for data in list_transition: replay_buffer.add( @@ -248,9 +240,7 @@ class ReplayBuffer: # If not provided, you can either raise an error or define a default: if state_keys is None: - raise ValueError( - "You must provide a list of keys in `state_keys` that define your 'state'." - ) + raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.") transitions: list[Transition] = [] num_frames = len(dataset) @@ -304,40 +294,36 @@ class ReplayBuffer: # -- Build batched states -- batch_state = {} for key in self.state_keys: - batch_state[key] = torch.cat( - [t["state"][key] for t in list_of_transitions], dim=0 - ).to(self.device) + batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to( + self.device + ) if key.startswith("observation.image") and self.use_drq: batch_state[key] = self.image_augmentation_function(batch_state[key]) # -- Build batched actions -- - batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to( - self.device - ) + batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device) # -- Build batched rewards -- - batch_rewards = torch.tensor( - [t["reward"] for t in list_of_transitions], dtype=torch.float32 - ).to(self.device) + batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to( + self.device + ) # -- Build batched next states -- batch_next_state = {} for key in self.state_keys: - batch_next_state[key] = torch.cat( - [t["next_state"][key] for t in list_of_transitions], dim=0 - ).to(self.device) + batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to( + self.device + ) if key.startswith("observation.image") and self.use_drq: - batch_next_state[key] = self.image_augmentation_function( - batch_next_state[key] - ) + batch_next_state[key] = self.image_augmentation_function(batch_next_state[key]) # -- Build batched dones -- - batch_dones = torch.tensor( - [t["done"] for t in list_of_transitions], dtype=torch.float32 - ).to(self.device) - batch_dones = torch.tensor( - [t["done"] for t in list_of_transitions], dtype=torch.float32 - ).to(self.device) + batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( + self.device + ) + batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to( + self.device + ) # Return a BatchTransition typed dict return BatchTransition( @@ -427,9 +413,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, # Hack: But if we do online traning, we do not need dataset_stats dataset_stats=None, - pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) - if cfg.resume - else None, + pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, device=device, ) assert isinstance(policy, nn.Module) @@ -438,9 +422,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # TODO: Handle resume - num_learnable_params = sum( - p.numel() for p in policy.parameters() if p.requires_grad - ) + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) log_output_dir(out_dir) @@ -481,16 +463,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No if interaction_step >= cfg.training.online_step_before_learning: action = policy.select_action(batch=obs) - next_obs, reward, done, truncated, info = online_env.step( - action.cpu().numpy() - ) + next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy()) else: action = online_env.action_space.sample() next_obs, reward, done, truncated, info = online_env.step(action) # HACK - action = torch.tensor(action, dtype=torch.float32).to( - device, non_blocking=True - ) + action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True) # HACK: For maniskill # next_obs = preprocess_observation(next_obs) @@ -500,20 +478,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Because we are using a single environment # we can safely assume that the episode is done if done[0] or truncated[0]: - logging.info( - f"Global step {interaction_step}: Episode reward: {sum_reward_episode}" - ) - logger.log_dict( - {"Sum episode reward": sum_reward_episode}, interaction_step - ) + logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}") + logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step) sum_reward_episode = 0 # HACK: This is for maniskill logging.info( f"global step {interaction_step}: episode success: {info['success'].float().item()} \n" ) - logger.log_dict( - {"Episode success": info["success"].float().item()}, interaction_step - ) + logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step) replay_buffer.add( state=obs, @@ -587,9 +559,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No training_infos["loss_actor"] = loss_actor.item() - loss_temperature = policy.compute_loss_temperature( - observations=observations - ) + loss_temperature = policy.compute_loss_temperature(observations=observations) optimizers["temperature"].zero_grad() loss_temperature.backward() optimizers["temperature"].step() @@ -611,9 +581,7 @@ def train_cli(cfg: dict): ) -def train_notebook( - out_dir=None, job_name=None, config_name="default", config_path="../configs" -): +def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"): from hydra import compose, initialize hydra.core.global_hydra.GlobalHydra.instance().clear() diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 65779d85..cdfea6b8 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -94,12 +94,8 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.ndim == 3 c, h, w = chw_float32_torch.shape - assert c < h and c < w, ( - f"expect channel first images, but instead {chw_float32_torch.shape}" - ) - hwc_uint8_numpy = ( - (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() - ) + assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" + hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() return hwc_uint8_numpy diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index ade5b7e5..ef0385e7 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -142,12 +142,8 @@ def run_server( ) ) - @app.route( - "///episode_" - ) - def show_episode( - dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes - ): + @app.route("///episode_") + def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes): repo_id = f"{dataset_namespace}/{dataset_name}" try: if dataset is None: @@ -158,9 +154,7 @@ def run_server( 400, ) dataset_version = ( - str(dataset.meta._version) - if isinstance(dataset, LeRobotDataset) - else dataset.codebase_version + str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version ) match = re.search(r"v(\d+)\.", dataset_version) if match: @@ -168,9 +162,7 @@ def run_server( if major_version < 2: return "Make sure to convert your LeRobotDataset to v2 & above." - episode_data_csv_str, columns, ignored_columns = get_episode_data( - dataset, episode_id - ) + episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id) dataset_info = { "repo_id": f"{dataset_namespace}/{dataset_name}", "num_samples": dataset.num_frames @@ -183,8 +175,7 @@ def run_server( } if isinstance(dataset, LeRobotDataset): video_paths = [ - dataset.meta.get_video_file_path(episode_id, key) - for key in dataset.meta.video_keys + dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys ] videos_info = [ { @@ -197,9 +188,7 @@ def run_server( ] tasks = dataset.meta.episodes[episode_id]["tasks"] else: - video_keys = [ - key for key, ft in dataset.features.items() if ft["dtype"] == "video" - ] + video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"] videos_info = [ { "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/" @@ -219,24 +208,16 @@ def run_server( ) response.raise_for_status() # Split into lines and parse each line as JSON - tasks_jsonl = [ - json.loads(line) for line in response.text.splitlines() if line.strip() - ] + tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()] - filtered_tasks_jsonl = [ - row for row in tasks_jsonl if row["episode_index"] == episode_id - ] + filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id] tasks = filtered_tasks_jsonl[0]["tasks"] videos_info[0]["language_instruction"] = tasks if episodes is None: episodes = list( - range( - dataset.num_episodes - if isinstance(dataset, LeRobotDataset) - else dataset.total_episodes - ) + range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) ) return render_template( @@ -263,11 +244,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) This file will be loaded by Dygraph javascript to plot data in real time.""" columns = [] - selected_columns = [ - col - for col, ft in dataset.features.items() - if ft["dtype"] in ["float32", "int32"] - ] + selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]] selected_columns.remove("timestamp") ignored_columns = [] @@ -288,10 +265,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) else dataset.features[column_name].shape[0] ) - if ( - "names" in dataset.features[column_name] - and dataset.features[column_name]["names"] - ): + if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: column_names = dataset.features[column_name]["names"] while not isinstance(column_names, list): column_names = list(column_names.values())[0] @@ -314,12 +288,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) else: repo_id = dataset.repo_id - url = ( - f"https://huggingface.co/datasets/{repo_id}/resolve/main/" - + dataset.data_path.format( - episode_chunk=int(episode_index) // dataset.chunks_size, - episode_index=episode_index, - ) + url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( + episode_chunk=int(episode_index) // dataset.chunks_size, + episode_index=episode_index, ) df = pd.read_parquet(url) data = df[selected_columns] # Select specific columns @@ -352,9 +323,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str] ] -def get_episode_language_instruction( - dataset: LeRobotDataset, ep_index: int -) -> list[str]: +def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: # check if the dataset has language instructions if "language_instruction" not in dataset.features: return None @@ -365,9 +334,7 @@ def get_episode_language_instruction( language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored # with the tf.tensor appearing in the string - return language_instruction.removeprefix("tf.Tensor(b'").removesuffix( - "', shape=(), dtype=string)" - ) + return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") def get_dataset_info(repo_id: str) -> IterableNamespace: @@ -403,9 +370,7 @@ def visualize_dataset_html( if force_override: shutil.rmtree(output_dir) else: - logging.info( - f"Output directory already exists. Loading from it: '{output_dir}'" - ) + logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") output_dir.mkdir(parents=True, exist_ok=True) diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py index 4b6bb900..80935d32 100644 --- a/lerobot/scripts/visualize_image_transforms.py +++ b/lerobot/scripts/visualize_image_transforms.py @@ -47,9 +47,7 @@ OUTPUT_DIR = Path("outputs/image_transforms") to_pil = ToPILImage() -def save_all_transforms( - cfg: ImageTransformsConfig, original_frame, output_dir, n_examples -): +def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples): output_dir_all = output_dir / "all" output_dir_all.mkdir(parents=True, exist_ok=True) @@ -62,9 +60,7 @@ def save_all_transforms( print(f" {output_dir_all}") -def save_each_transform( - cfg: ImageTransformsConfig, original_frame, output_dir, n_examples -): +def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples): if not cfg.enable: logging.warning( "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`." @@ -93,15 +89,9 @@ def save_each_transform( tf_cfg_kwgs_max[key] = [max_, max_] tf_cfg_kwgs_avg[key] = [avg, avg] - tf_min = make_transform_from_config( - replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}) - ) - tf_max = make_transform_from_config( - replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}) - ) - tf_avg = make_transform_from_config( - replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}) - ) + tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})) + tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})) + tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})) tf_frame_min = tf_min(original_frame) tf_frame_max = tf_max(original_frame) @@ -115,9 +105,7 @@ def save_each_transform( @draccus.wrap() -def visualize_image_transforms( - cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5 -): +def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5): dataset = LeRobotDataset( repo_id=cfg.repo_id, episodes=cfg.episodes, diff --git a/tests/artifacts/datasets/save_dataset_to_safetensors.py b/tests/artifacts/datasets/save_dataset_to_safetensors.py index 3159e8c8..74d42a3d 100644 --- a/tests/artifacts/datasets/save_dataset_to_safetensors.py +++ b/tests/artifacts/datasets/save_dataset_to_safetensors.py @@ -52,13 +52,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 frames at the middle of first episode - i = int( - ( - dataset.episode_data_index["to"][0].item() - - dataset.episode_data_index["from"][0].item() - ) - / 2 - ) + i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index c1465074..106f0dc0 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -51,9 +51,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): batch = next(iter(dataloader)) loss, output_dict = policy.forward(batch) if output_dict is not None: - output_dict = { - k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor) - } + output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} output_dict["loss"] = loss else: output_dict = {"loss": loss} @@ -71,9 +69,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): param_stats = {} for key, param in policy.named_parameters(): param_stats[f"{key}_mean"] = param.mean() - param_stats[f"{key}_std"] = ( - param.std() if param.numel() > 1 else torch.tensor(float(0.0)) - ) + param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0)) optimizer.zero_grad() policy.reset() @@ -100,15 +96,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): else: actions_queue = train_cfg.policy.n_action_repeats - actions = { - str(i): policy.select_action(obs).contiguous() for i in range(actions_queue) - } + actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} return output_dict, grad_stats, param_stats, actions -def save_policy_to_safetensors( - output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict -): +def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict): if output_dir.exists(): print(f"Overwrite existing safetensors in '{output_dir}':") print(f" - Validate with: `git add {output_dir}`") @@ -116,9 +108,7 @@ def save_policy_to_safetensors( shutil.rmtree(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - output_dict, grad_stats, param_stats, actions = get_policy_stats( - ds_repo_id, policy_name, policy_kwargs - ) + output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) save_file(output_dict, output_dir / "output_dict.safetensors") save_file(grad_stats, output_dir / "grad_stats.safetensors") save_file(param_stats, output_dir / "param_stats.safetensors") @@ -151,7 +141,5 @@ if __name__ == "__main__": raise RuntimeError("No policies were provided!") for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg: ds_name = ds_repo_id.split("/")[-1] - output_dir = ( - Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}" - ) + output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}" save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs) diff --git a/tests/cameras/mock_pyrealsense2.py b/tests/cameras/mock_pyrealsense2.py index 38baf5ba..c477eb06 100644 --- a/tests/cameras/mock_pyrealsense2.py +++ b/tests/cameras/mock_pyrealsense2.py @@ -30,9 +30,7 @@ class config: # noqa: N801 def enable_device(self, device_id: str): self.device_enabled = device_id - def enable_stream( - self, stream_type: stream, width=None, height=None, color_format=None, fps=None - ): + def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None): self.stream_type = stream_type # Overwrite default values when possible self.width = 848 if width is None else width diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py index 0b990fa3..1a8cceed 100644 --- a/tests/configs/test_plugin_loading.py +++ b/tests/configs/test_plugin_loading.py @@ -9,9 +9,7 @@ from lerobot.common.envs.configs import EnvConfig from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap -def create_plugin_code( - *, base_class: str = "EnvConfig", plugin_name: str = "test_env" -) -> str: +def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str: """Creates a dummy plugin module that implements its own EnvConfig subclass.""" return f""" from dataclasses import dataclass diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index dbedd35c..e93c5274 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -31,11 +31,7 @@ from lerobot.common.datasets.compute_stats import ( def mock_load_image_as_numpy(path, dtype, channel_first): - return ( - np.ones((3, 32, 32), dtype=dtype) - if channel_first - else np.ones((32, 32, 3), dtype=dtype) - ) + return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) @pytest.fixture @@ -81,20 +77,9 @@ def test_sample_images(mock_load): def test_get_feature_stats_images(): data = np.random.rand(100, 3, 32, 32) stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) - assert ( - "min" in stats - and "max" in stats - and "mean" in stats - and "std" in stats - and "count" in stats - ) + assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats np.testing.assert_equal(stats["count"], np.array([100])) - assert ( - stats["min"].shape - == stats["max"].shape - == stats["mean"].shape - == stats["std"].shape - ) + assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape def test_get_feature_stats_axis_0_keepdims(sample_array): @@ -315,47 +300,31 @@ def test_aggregate_stats(): for ep_stats in all_stats: for fkey, stats in ep_stats.items(): for k in stats: - stats[k] = np.array( - stats[k], dtype=np.int64 if k == "count" else np.float32 - ) + stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) if fkey == "observation.image" and k != "count": - stats[k] = stats[k].reshape( - 3, 1, 1 - ) # for normalization on image channels + stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) # cast to numpy for fkey, stats in expected_agg_stats.items(): for k in stats: - stats[k] = np.array( - stats[k], dtype=np.int64 if k == "count" else np.float32 - ) + stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) if fkey == "observation.image" and k != "count": - stats[k] = stats[k].reshape( - 3, 1, 1 - ) # for normalization on image channels + stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) results = aggregate_stats(all_stats) for fkey in expected_agg_stats: - np.testing.assert_allclose( - results[fkey]["min"], expected_agg_stats[fkey]["min"] - ) - np.testing.assert_allclose( - results[fkey]["max"], expected_agg_stats[fkey]["max"] - ) - np.testing.assert_allclose( - results[fkey]["mean"], expected_agg_stats[fkey]["mean"] - ) + np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"]) + np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"]) + np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"]) np.testing.assert_allclose( results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04, ) - np.testing.assert_allclose( - results[fkey]["count"], expected_agg_stats[fkey]["count"] - ) + np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 571b1fcc..38f722dc 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -72,9 +72,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): # Instantiate both ways robot = make_robot("koch", mock=True) root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create( - repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create - ) + dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init) @@ -129,9 +127,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n", ): - dataset.add_frame( - {"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"} - ) + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): @@ -141,9 +137,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n", ): - dataset.add_frame( - {"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"} - ) + dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): @@ -151,9 +145,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( ValueError, - match=re.escape( - "The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n" - ), + match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), ): dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) @@ -175,9 +167,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( ValueError, - match=re.escape( - "The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n" - ), + match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), ): dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"}) @@ -471,9 +461,7 @@ def test_flatten_unflatten_dict(): d = unflatten_dict(flatten_dict(d)) # test equality between nested dicts - assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), ( - f"{original_d} != {d}" - ) + assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" @pytest.mark.parametrize( @@ -527,13 +515,7 @@ def test_backward_compatibility(repo_id): load_and_compare(i + 1) # test 2 frames at the middle of first episode - i = int( - ( - dataset.episode_data_index["to"][0].item() - - dataset.episode_data_index["from"][0].item() - ) - / 2 - ) + i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) load_and_compare(i) load_and_compare(i + 1) diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index 62875116..59b90b4c 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -71,12 +71,8 @@ def unsynced_timestamps_factory(synced_timestamps_factory): def _create_unsynced_timestamps( fps: int = 30, tolerance_s: float = 1e-4 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory( - fps=fps - ) - timestamps[30] += ( - tolerance_s * 1.1 - ) # Modify a single timestamp just outside tolerance + timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) + timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance return timestamps, episode_indices, episode_data_index return _create_unsynced_timestamps @@ -87,12 +83,8 @@ def slightly_off_timestamps_factory(synced_timestamps_factory): def _create_slightly_off_timestamps( fps: int = 30, tolerance_s: float = 1e-4 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory( - fps=fps - ) - timestamps[30] += ( - tolerance_s * 0.9 - ) # Modify a single timestamp just inside tolerance + timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) + timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance return timestamps, episode_indices, episode_data_index return _create_slightly_off_timestamps @@ -105,9 +97,7 @@ def valid_delta_timestamps_factory(): keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10), ) -> dict: - delta_timestamps = { - key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys - } + delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys} return delta_timestamps return _create_valid_delta_timestamps @@ -144,9 +134,7 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory): @pytest.fixture(scope="module") def delta_indices_factory(): - def _delta_indices( - keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10) - ) -> dict: + def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict: return {key: list(range(*min_max_range)) for key in keys} return _delta_indices @@ -198,9 +186,7 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory): fps = 30 tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory( - fps, tolerance_s - ) + timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s) result = check_timestamps_sync( timestamps=timestamps, episode_indices=ep_idx, @@ -241,9 +227,7 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory): def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory): fps = 30 tolerance_s = 1e-4 - slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory( - fps, tolerance_s - ) + slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s) result = check_delta_timestamps( delta_timestamps=slightly_off_delta_timestamps, fps=fps, diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 1252d6a7..e896828f 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -82,11 +82,7 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( enable=True, - tfs={ - "brightness": ImageTransformConfig( - type="ColorJitter", kwargs={"brightness": min_max} - ) - }, + tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})}, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = v2.ColorJitter(brightness=min_max) @@ -98,11 +94,7 @@ def test_get_image_transforms_contrast(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( enable=True, - tfs={ - "contrast": ImageTransformConfig( - type="ColorJitter", kwargs={"contrast": min_max} - ) - }, + tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = v2.ColorJitter(contrast=min_max) @@ -114,11 +106,7 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( enable=True, - tfs={ - "saturation": ImageTransformConfig( - type="ColorJitter", kwargs={"saturation": min_max} - ) - }, + tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})}, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = v2.ColorJitter(saturation=min_max) @@ -142,11 +130,7 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( enable=True, - tfs={ - "sharpness": ImageTransformConfig( - type="SharpnessJitter", kwargs={"sharpness": min_max} - ) - }, + tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})}, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = SharpnessJitter(sharpness=min_max) @@ -362,9 +346,7 @@ def test_save_all_transforms(img_tensor_factory, tmp_path): # Check if the combined transforms directory exists and contains the right files combined_transforms_dir = tmp_path / "all" - assert combined_transforms_dir.exists(), ( - "Combined transforms directory was not created." - ) + assert combined_transforms_dir.exists(), "Combined transforms directory was not created." assert any(combined_transforms_dir.iterdir()), ( "No transformed images found in combined transforms directory." ) @@ -386,9 +368,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path): for transform in transforms: transform_dir = tmp_path / transform assert transform_dir.exists(), f"{transform} directory was not created." - assert any(transform_dir.iterdir()), ( - f"No transformed images found in {transform} directory." - ) + assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory." # Check for specific files within each transform directory expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [ diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index 6655b415..802fe0d3 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -187,9 +187,7 @@ def test_save_image_torch(tmp_path, img_tensor_factory): writer.wait_until_done() assert fpath.exists() saved_image = np.array(Image.open(fpath)) - expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype( - np.uint8 - ) + expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) assert np.array_equal(expected_image, saved_image) finally: writer.stop() @@ -204,9 +202,7 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory): writer.wait_until_done() assert fpath.exists() saved_image = np.array(Image.open(fpath)) - expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype( - np.uint8 - ) + expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) assert np.array_equal(expected_image, saved_image) finally: writer.stop() @@ -296,9 +292,7 @@ def test_wait_until_done(tmp_path, img_array_factory): writer = AsyncImageWriter(num_processes=0, num_threads=4) try: num_images = 100 - image_arrays = [ - img_array_factory(height=500, width=500) for _ in range(num_images) - ] + image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)] fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)] for image_array, fpath in zip(image_arrays, fpaths, strict=True): fpath.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/datasets/test_online_buffer.py b/tests/datasets/test_online_buffer.py index 52cdf684..f05b77a9 100644 --- a/tests/datasets/test_online_buffer.py +++ b/tests/datasets/test_online_buffer.py @@ -44,23 +44,13 @@ def make_new_buffer( return buffer, write_dir -def make_spoof_data_frames( - n_episodes: int, n_frames_per_episode: int -) -> dict[str, np.ndarray]: +def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]: new_data = { - data_key: np.arange( - n_frames_per_episode * n_episodes * np.prod(data_shape) - ).reshape(-1, *data_shape), + data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape), OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes), - OnlineBuffer.EPISODE_INDEX_KEY: np.repeat( - np.arange(n_episodes), n_frames_per_episode - ), - OnlineBuffer.FRAME_INDEX_KEY: np.tile( - np.arange(n_frames_per_episode), n_episodes - ), - OnlineBuffer.TIMESTAMP_KEY: np.tile( - np.arange(n_frames_per_episode) / fps, n_episodes - ), + OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode), + OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes), + OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes), } return new_data @@ -176,9 +166,7 @@ def test_delta_timestamps_within_tolerance(): buffer.tolerance_s = 0.04 item = buffer[2] data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"] - torch.testing.assert_close( - data, torch.tensor([0, 2, 3]), msg="Data does not match expected values" - ) + torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values") assert not is_pad.any(), "Unexpected padding detected" @@ -214,9 +202,7 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range(): buffer.tolerance_s = 0.04 item = buffer[2] data, is_pad = item["index"], item["index_is_pad"] - assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), ( - "Data does not match expected values" - ) + assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), ( "Padding does not match expected values" ) @@ -233,15 +219,11 @@ def test_compute_sampler_weights_trivial( online_dataset_size: int, online_sampling_ratio: float, ): - offline_dataset = lerobot_dataset_factory( - tmp_path, total_episodes=1, total_frames=offline_dataset_size - ) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size) online_dataset, _ = make_new_buffer() if online_dataset_size > 0: online_dataset.add_data( - make_spoof_data_frames( - n_episodes=2, n_frames_per_episode=online_dataset_size // 2 - ) + make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2) ) weights = compute_sampler_weights( @@ -252,26 +234,18 @@ def test_compute_sampler_weights_trivial( if offline_dataset_size == 0 or online_dataset_size == 0: expected_weights = torch.ones(offline_dataset_size + online_dataset_size) elif online_sampling_ratio == 0: - expected_weights = torch.cat( - [torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)] - ) + expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]) elif online_sampling_ratio == 1: - expected_weights = torch.cat( - [torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)] - ) + expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]) expected_weights /= expected_weights.sum() torch.testing.assert_close(weights, expected_weights) def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory( - tmp_path, total_episodes=1, total_frames=4 - ) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) online_dataset, _ = make_new_buffer() - online_dataset.add_data( - make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2) - ) + online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) online_sampling_ratio = 0.8 weights = compute_sampler_weights( offline_dataset, @@ -284,17 +258,11 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p ) -def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n( - lerobot_dataset_factory, tmp_path -): +def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path): # Arbitrarily set small dataset sizes, making sure to have uneven sizes. - offline_dataset = lerobot_dataset_factory( - tmp_path, total_episodes=1, total_frames=4 - ) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4) online_dataset, _ = make_new_buffer() - online_dataset.add_data( - make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2) - ) + online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) weights = compute_sampler_weights( offline_dataset, online_dataset=online_dataset, @@ -309,13 +277,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n( def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path): """Note: test copied from test_sampler.""" - offline_dataset = lerobot_dataset_factory( - tmp_path, total_episodes=1, total_frames=2 - ) + offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2) online_dataset, _ = make_new_buffer() - online_dataset.add_data( - make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2) - ) + online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)) weights = compute_sampler_weights( offline_dataset, @@ -324,6 +288,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp online_sampling_ratio=0.5, online_drop_n_last_frames=1, ) - torch.testing.assert_close( - weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]) - ) + torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 26a40ad9..7cd106d1 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -58,9 +58,7 @@ def get_task_index(task_dicts: dict, task: str) -> int: @pytest.fixture(scope="session") def img_tensor_factory(): - def _create_img_tensor( - height=100, width=100, channels=3, dtype=torch.float32 - ) -> torch.Tensor: + def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor: return torch.rand((channels, height, width), dtype=dtype) return _create_img_tensor @@ -68,14 +66,10 @@ def img_tensor_factory(): @pytest.fixture(scope="session") def img_array_factory(): - def _create_img_array( - height=100, width=100, channels=3, dtype=np.uint8 - ) -> np.ndarray: + def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray: if np.issubdtype(dtype, np.unsignedinteger): # Int array in [0, 255] range - img_array = np.random.randint( - 0, 256, size=(height, width, channels), dtype=dtype - ) + img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) elif np.issubdtype(dtype, np.floating): # Float array in [0, 1] range img_array = np.random.rand(height, width, channels).astype(dtype) @@ -104,13 +98,10 @@ def features_factory(): ) -> dict: if use_videos: camera_ft = { - key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} - for key, ft in camera_features.items() + key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items() } else: - camera_ft = { - key: {"dtype": "image", **ft} for key, ft in camera_features.items() - } + camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()} return { **motor_features, **camera_ft, @@ -231,9 +222,7 @@ def episodes_factory(tasks_factory): if total_episodes <= 0 or total_frames <= 0: raise ValueError("num_episodes and total_length must be positive integers.") if total_frames < total_episodes: - raise ValueError( - "total_length must be greater than or equal to num_episodes." - ) + raise ValueError("total_length must be greater than or equal to num_episodes.") if not tasks: min_tasks = 2 if multi_task else 1 @@ -241,14 +230,10 @@ def episodes_factory(tasks_factory): tasks = tasks_factory(total_tasks) if total_episodes < len(tasks) and not multi_task: - raise ValueError( - "The number of tasks should be less than the number of episodes." - ) + raise ValueError("The number of tasks should be less than the number of episodes.") # Generate random lengths that sum up to total_length - lengths = np.random.multinomial( - total_frames, [1 / total_episodes] * total_episodes - ).tolist() + lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() tasks_list = [task_dict["task"] for task_dict in tasks.values()] num_tasks_available = len(tasks_list) @@ -256,13 +241,9 @@ def episodes_factory(tasks_factory): episodes = {} remaining_tasks = tasks_list.copy() for ep_idx in range(total_episodes): - num_tasks_in_episode = ( - random.randint(1, min(3, num_tasks_available)) if multi_task else 1 - ) + num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list - episode_tasks = random.sample( - tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)) - ) + episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))) if remaining_tasks: for task in episode_tasks: remaining_tasks.remove(task) @@ -279,9 +260,7 @@ def episodes_factory(tasks_factory): @pytest.fixture(scope="session") -def hf_dataset_factory( - features_factory, tasks_factory, episodes_factory, img_array_factory -): +def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): def _create_hf_dataset( features: dict | None = None, tasks: list[dict] | None = None, @@ -300,12 +279,8 @@ def hf_dataset_factory( episode_index_col = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64) for ep_dict in episodes.values(): - timestamp_col = np.concatenate( - (timestamp_col, np.arange(ep_dict["length"]) / fps) - ) - frame_index_col = np.concatenate( - (frame_index_col, np.arange(ep_dict["length"], dtype=int)) - ) + timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) + frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) episode_index_col = np.concatenate( ( episode_index_col, @@ -313,9 +288,7 @@ def hf_dataset_factory( ) ) ep_task_index = get_task_index(tasks, ep_dict["tasks"][0]) - task_index = np.concatenate( - (task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)) - ) + task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))) index_col = np.arange(len(episode_index_col)) @@ -327,9 +300,7 @@ def hf_dataset_factory( for _ in range(len(index_col)) ] elif ft["shape"][0] > 1 and ft["dtype"] != "video": - robot_cols[key] = np.random.random( - (len(index_col), ft["shape"][0]) - ).astype(ft["dtype"]) + robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"]) hf_features = get_hf_features_from_features(features) dataset = datasets.Dataset.from_dict( @@ -392,9 +363,7 @@ def lerobot_dataset_metadata_factory( episodes=episodes, ) with ( - patch( - "lerobot.common.datasets.lerobot_dataset.get_safe_version" - ) as mock_get_safe_version_patch, + patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, patch( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, @@ -442,9 +411,7 @@ def lerobot_dataset_factory( if not stats: stats = stats_factory(features=info["features"]) if not episodes_stats: - episodes_stats = episodes_stats_factory( - features=info["features"], total_episodes=total_episodes - ) + episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episode_dicts: @@ -455,9 +422,7 @@ def lerobot_dataset_factory( multi_task=multi_task, ) if not hf_dataset: - hf_dataset = hf_dataset_factory( - tasks=tasks, episodes=episode_dicts, fps=info["fps"] - ) + hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"]) mock_snapshot_download = mock_snapshot_download_factory( info=info, @@ -477,12 +442,8 @@ def lerobot_dataset_factory( episodes=episode_dicts, ) with ( - patch( - "lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata" - ) as mock_metadata_patch, - patch( - "lerobot.common.datasets.lerobot_dataset.get_safe_version" - ) as mock_get_safe_version_patch, + patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, + patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, patch( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 0ab33f84..d869586f 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -59,9 +59,7 @@ def stats_path(stats_factory): @pytest.fixture(scope="session") def episodes_stats_path(episodes_stats_factory): - def _create_episodes_stats_jsonl_file( - dir: Path, episodes_stats: list[dict] | None = None - ) -> Path: + def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: if not episodes_stats: episodes_stats = episodes_stats_factory() fpath = dir / EPISODES_STATS_PATH diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index cb2195b1..42f4875d 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -72,16 +72,12 @@ def mock_snapshot_download_factory( tasks=tasks, ) if not hf_dataset: - hf_dataset = hf_dataset_factory( - tasks=tasks, episodes=episodes, fps=info["fps"] - ) + hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) def _extract_episode_index_from_path(fpath: str) -> int: path = Path(fpath) if path.suffix == ".parquet" and path.stem.startswith("episode_"): - episode_index = int( - path.stem[len("episode_") :] - ) # 'episode_000000' -> 0 + episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0 return episode_index else: return None @@ -112,9 +108,7 @@ def mock_snapshot_download_factory( for episode_dict in episodes.values(): ep_idx = episode_dict["episode_index"] ep_chunk = ep_idx // info["chunks_size"] - data_path = info["data_path"].format( - episode_chunk=ep_chunk, episode_index=ep_idx - ) + data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) data_files.append(data_path) all_files.extend(data_files) @@ -129,9 +123,7 @@ def mock_snapshot_download_factory( if rel_path.startswith("data/"): episode_index = _extract_episode_index_from_path(rel_path) if episode_index is not None: - _ = single_episode_parquet_path( - local_dir, episode_index, hf_dataset, info - ) + _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info) if rel_path == INFO_PATH: _ = info_path(local_dir, info) elif rel_path == STATS_PATH: diff --git a/tests/fixtures/optimizers.py b/tests/fixtures/optimizers.py index 149f0255..65488566 100644 --- a/tests/fixtures/optimizers.py +++ b/tests/fixtures/optimizers.py @@ -35,7 +35,5 @@ def optimizer(model_params): @pytest.fixture def scheduler(optimizer): - config = VQBeTSchedulerConfig( - num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5 - ) + config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) return config.build(optimizer, num_training_steps=100) diff --git a/tests/motors/mock_dynamixel_sdk.py b/tests/motors/mock_dynamixel_sdk.py index 387ff528..ee399f96 100644 --- a/tests/motors/mock_dynamixel_sdk.py +++ b/tests/motors/mock_dynamixel_sdk.py @@ -80,9 +80,7 @@ class GroupSyncRead: def addParam(self, motor_index): # noqa: N802 # Initialize motor default values if motor_index not in self.packet_handler.data: - self.packet_handler.data[motor_index] = get_default_motor_values( - motor_index - ) + self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) def txRxPacket(self): # noqa: N802 return COMM_SUCCESS diff --git a/tests/motors/mock_scservo_sdk.py b/tests/motors/mock_scservo_sdk.py index be6be756..37f6d0d5 100644 --- a/tests/motors/mock_scservo_sdk.py +++ b/tests/motors/mock_scservo_sdk.py @@ -91,9 +91,7 @@ class GroupSyncRead: def addParam(self, motor_index): # noqa: N802 # Initialize motor default values if motor_index not in self.packet_handler.data: - self.packet_handler.data[motor_index] = get_default_motor_values( - motor_index - ) + self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) def txRxPacket(self): # noqa: N802 return COMM_SUCCESS diff --git a/tests/motors/test_motors.py b/tests/motors/test_motors.py index c8013953..6067dece 100644 --- a/tests/motors/test_motors.py +++ b/tests/motors/test_motors.py @@ -79,9 +79,7 @@ def test_configure_motors_all_ids_1(request, motor_type, mock): else: raise ValueError(motor_type) - input( - "Are you sure you want to re-configure the motors? Press enter to continue..." - ) + input("Are you sure you want to re-configure the motors? Press enter to continue...") # This test expect the configuration was already correct. motors_bus = make_motors_bus(motor_type, mock=mock) motors_bus.connect() diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index e51558cd..17637663 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -43,9 +43,7 @@ def test_diffuser_scheduler(optimizer): def test_vqbet_scheduler(optimizer): - config = VQBeTSchedulerConfig( - num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5 - ) + config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) scheduler = config.build(optimizer, num_training_steps=100) assert isinstance(scheduler, LambdaLR) diff --git a/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py index 84b96b6d..846899a4 100644 --- a/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py +++ b/tests/policies/hilserl/classifier/check_hiserl_reward_classifier.py @@ -46,9 +46,7 @@ def train_evaluate_multiclass_classifier(): logging.info( f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}" ) - multiclass_config = ClassifierConfig( - model_name="microsoft/resnet-18", device=DEVICE, num_classes=10 - ) + multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10) multiclass_classifier = Classifier(multiclass_config) trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor()) @@ -119,18 +117,10 @@ def train_evaluate_multiclass_classifier(): test_probs = torch.stack(test_probs) accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes) - precision = Precision( - task="multiclass", average="weighted", num_classes=multiclass_num_classes - ) - recall = Recall( - task="multiclass", average="weighted", num_classes=multiclass_num_classes - ) - f1 = F1Score( - task="multiclass", average="weighted", num_classes=multiclass_num_classes - ) - auroc = AUROC( - task="multiclass", num_classes=multiclass_num_classes, average="weighted" - ) + precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes) + recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes) + f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes) + auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted") # Calculate metrics acc = accuracy(test_predictions, test_labels) @@ -159,28 +149,18 @@ def train_evaluate_binary_classifier(): new_label = float(1.0) if label == target_class else float(0.0) new_targets.append(new_label) - dataset.targets = ( - new_targets # Replace the original labels with the binary ones - ) + dataset.targets = new_targets # Replace the original labels with the binary ones return dataset - binary_train_dataset = CIFAR10( - root="data", train=True, download=True, transform=ToTensor() - ) - binary_test_dataset = CIFAR10( - root="data", train=False, download=True, transform=ToTensor() - ) + binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor()) + binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor()) # Apply one-vs-rest labeling binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class) binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class) - binary_trainloader = DataLoader( - binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True - ) - binary_testloader = DataLoader( - binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False - ) + binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True) + binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False) binary_epoch = 1 diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index eb1c0003..2d530379 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -196,13 +196,9 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # Test updating the policy (and test that it does not mutate the batch) batch_ = deepcopy(batch) policy.forward(batch) - assert set(batch) == set(batch_), ( - "Batch keys are not the same after a forward pass." - ) + assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass." assert all( - torch.equal(batch[k], batch_[k]) - if isinstance(batch[k], torch.Tensor) - else batch[k] == batch_[k] + torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k] for k in batch ), "Batch values are not the same after a forward pass." @@ -214,9 +210,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): observation = preprocess_observation(observation) # send observation to device/gpu - observation = { - key: observation[key].to(DEVICE, non_blocking=True) for key in observation - } + observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} # get the next action for the environment (also check that the observation batch is not modified) observation_ = deepcopy(observation) @@ -241,12 +235,8 @@ def test_act_backbone_lr(): cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download - dataset=DatasetConfig( - repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0] - ), - policy=make_policy_config( - "act", optimizer_lr=0.01, optimizer_lr_backbone=0.001 - ), + dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), + policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001), ) cfg.validate() # Needed for auto-setting some parameters @@ -269,9 +259,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str): policy_cls = get_policy_class(policy_name) policy_cfg = make_policy_config(policy_name) features = dataset_to_policy_features(dummy_dataset_metadata.features) - policy_cfg.output_features = { - key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION - } + policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} policy_cfg.input_features = { key: ft for key, ft in features.items() if key not in policy_cfg.output_features } @@ -283,9 +271,7 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: policy_cls = get_policy_class(policy_name) policy_cfg = make_policy_config(policy_name) features = dataset_to_policy_features(dummy_dataset_metadata.features) - policy_cfg.output_features = { - key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION - } + policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} policy_cfg.input_features = { key: ft for key, ft in features.items() if key not in policy_cfg.output_features } @@ -294,9 +280,7 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}" policy.save_pretrained(save_dir) loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg) - torch.testing.assert_close( - list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0 - ) + torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) @pytest.mark.parametrize("insert_temporal_dim", [False, True]) @@ -436,9 +420,7 @@ def test_normalize(insert_temporal_dim): # pass if it's run on another platform due to floating point errors @require_x86_64_kernel @require_cpu -def test_backward_compatibility( - ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str -): +def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str): """ NOTE: If this test does not pass, and you have intentionally changed something in the policy: 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should @@ -452,17 +434,13 @@ def test_backward_compatibility( 6. Remember to stage and commit the resulting changes to `tests/artifacts`. """ ds_name = ds_repo_id.split("/")[-1] - artifact_dir = ( - Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}" - ) + artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}" saved_output_dict = load_file(artifact_dir / "output_dict.safetensors") saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors") saved_param_stats = load_file(artifact_dir / "param_stats.safetensors") saved_actions = load_file(artifact_dir / "actions.safetensors") - output_dict, grad_stats, param_stats, actions = get_policy_stats( - ds_repo_id, policy_name, policy_kwargs - ) + output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) for key in saved_output_dict: torch.testing.assert_close(output_dict[key], saved_output_dict[key]) @@ -471,12 +449,8 @@ def test_backward_compatibility( for key in saved_param_stats: torch.testing.assert_close(param_stats[key], saved_param_stats[key]) for key in saved_actions: - rtol, atol = ( - (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) - ) # HACK - torch.testing.assert_close( - actions[key], saved_actions[key], rtol=rtol, atol=atol - ) + rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK + torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol) def test_act_temporal_ensembler(): @@ -502,9 +476,7 @@ def test_act_temporal_ensembler(): batch_size = batch_seq.shape[0] # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length` # dimension of `batch_seq`. - weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze( - -1 - ) + weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1) # Simulate stepping through a rollout and computing a batch of actions with model on each step. for i in range(episode_length): @@ -527,8 +499,7 @@ def test_act_temporal_ensembler(): episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :] seq_slice = batch_seq[:, episode_step_indices, chunk_indices] offline_avg = ( - einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") - / weights[: i + 1].sum() + einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum() ) # Sanity check. The average should be between the extrema. assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg) diff --git a/tests/robots/test_control_robot.py b/tests/robots/test_control_robot.py index f8263920..04325e8d 100644 --- a/tests/robots/test_control_robot.py +++ b/tests/robots/test_control_robot.py @@ -179,9 +179,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock): assert dataset.meta.total_episodes == 2 assert len(dataset) == 2 - replay_cfg = ReplayControlConfig( - episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False - ) + replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False) replay(robot, replay_cfg) policy_cfg = ACTConfig() @@ -336,12 +334,8 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock) ) dataset = record(robot, rec_cfg) - assert not mock_events["rerecord_episode"], ( - "`rerecord_episode` wasn't properly reset to False" - ) - assert not mock_events["exit_early"], ( - "`exit_early` wasn't properly reset to False" - ) + assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False" + assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" assert len(dataset) == 1, "`dataset` should contain only 1 frame" @@ -391,9 +385,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): dataset = record(robot, rec_cfg) - assert not mock_events["exit_early"], ( - "`exit_early` wasn't properly reset to False" - ) + assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" assert len(dataset) == 1, "`dataset` should contain only 1 frame" @@ -402,9 +394,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): [("koch", True, 0), ("koch", True, 1)], ) @require_robot -def test_record_with_event_stop_recording( - tmp_path, request, robot_type, mock, num_image_writer_processes -): +def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: @@ -450,7 +440,5 @@ def test_record_with_event_stop_recording( dataset = record(robot, rec_cfg) - assert not mock_events["exit_early"], ( - "`exit_early` wasn't properly reset to False" - ) + assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" assert len(dataset) == 1, "`dataset` should contain only 1 frame" diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py index ba676bc3..08da0608 100644 --- a/tests/robots/test_robots.py +++ b/tests/robots/test_robots.py @@ -108,9 +108,7 @@ def test_robot(tmp_path, request, robot_type, mock): assert "observation.state" in observation assert isinstance(observation["observation.state"], torch.Tensor) assert observation["observation.state"].ndim == 1 - dim_state = sum( - len(robot.follower_arms[name].motors) for name in robot.follower_arms - ) + dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms) assert observation["observation.state"].shape[0] == dim_state # Cameras for name in robot.cameras: @@ -121,9 +119,7 @@ def test_robot(tmp_path, request, robot_type, mock): assert "action" in action assert isinstance(action["action"], torch.Tensor) assert action["action"].ndim == 1 - dim_action = sum( - len(robot.follower_arms[name].motors) for name in robot.follower_arms - ) + dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms) assert action["action"].shape[0] == dim_action # TODO(rcadene): test if observation and action data are returned as expected @@ -134,9 +130,7 @@ def test_robot(tmp_path, request, robot_type, mock): if "image" in name: # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames continue - torch.testing.assert_close( - captured_observation[name], observation[name], rtol=1e-4, atol=1 - ) + torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1) assert captured_observation[name].shape == observation[name].shape # Test send_action can run diff --git a/tests/test_train_hilserl_classifier.py b/tests/test_train_hilserl_classifier.py index 837eebc8..58344913 100644 --- a/tests/test_train_hilserl_classifier.py +++ b/tests/test_train_hilserl_classifier.py @@ -69,9 +69,7 @@ def test_create_balanced_sampler(): labels = [item["label"] for item in data] class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32) class_weights = 1.0 / class_counts - expected_weights = torch.tensor( - [class_weights[label] for label in labels], dtype=torch.float32 - ) + expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32) # Test that the weights are correct assert torch.allclose(weights, expected_weights) @@ -224,16 +222,10 @@ def test_resume_function( ): # Initialize Hydra test_file_dir = os.path.dirname(os.path.abspath(__file__)) - config_dir = os.path.abspath( - os.path.join(test_file_dir, "..", "lerobot", "configs", "policy") - ) - assert os.path.exists(config_dir), ( - f"Config directory does not exist at {config_dir}" - ) + config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")) + assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}" - with initialize_config_dir( - config_dir=config_dir, job_name="test_app", version_base="1.2" - ): + with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"): cfg = compose( config_name="hilserl_classifier", overrides=[ @@ -258,9 +250,7 @@ def test_resume_function( mock_init_hydra_config.return_value = cfg # Mock dataset - dataset = MockDataset( - [{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)] - ) + dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]) mock_dataset.return_value = dataset # Mock checkpoint handling diff --git a/tests/utils.py b/tests/utils.py index 23b297cb..a3eae555 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -31,11 +31,7 @@ from lerobot.common.robot_devices.motors.utils import ( ) from lerobot.common.utils.import_utils import is_package_available -DEVICE = ( - os.environ.get("LEROBOT_TEST_DEVICE", "cuda") - if torch.cuda.is_available() - else "cpu" -) +DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" TEST_ROBOT_TYPES = [] for robot_type in available_robots: @@ -51,13 +47,9 @@ for motor_type in available_motors: # Camera indices used for connecting physical cameras OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) -INTELREALSENSE_SERIAL_NUMBER = int( - os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614) -) +INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)) -DYNAMIXEL_PORT = os.environ.get( - "LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081" -) +DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081") DYNAMIXEL_MOTORS = { "shoulder_pan": [1, "xl430-w250"], "shoulder_lift": [2, "xl430-w250"], @@ -67,9 +59,7 @@ DYNAMIXEL_MOTORS = { "gripper": [6, "xl330-m288"], } -FEETECH_PORT = os.environ.get( - "LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971" -) +FEETECH_PORT = os.environ.get("LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971") FEETECH_MOTORS = { "shoulder_pan": [1, "sts3215"], "shoulder_lift": [2, "sts3215"], @@ -168,13 +158,9 @@ def require_package_arg(func): if "required_packages" in arg_names: # Get the index of 'required_packages' and retrieve the value from args index = arg_names.index("required_packages") - required_packages = ( - args[index] if len(args) > index else kwargs.get("required_packages") - ) + required_packages = args[index] if len(args) > index else kwargs.get("required_packages") else: - raise ValueError( - "Function does not have 'required_packages' as an argument." - ) + raise ValueError("Function does not have 'required_packages' as an argument.") if required_packages is None: return func(*args, **kwargs) @@ -231,17 +217,11 @@ def require_robot(func): mock = kwargs.get("mock") if robot_type is None: - raise ValueError( - "The 'robot_type' must be an argument of the test function." - ) + raise ValueError("The 'robot_type' must be an argument of the test function.") if request is None: - raise ValueError( - "The 'request' fixture must be an argument of the test function." - ) + raise ValueError("The 'request' fixture must be an argument of the test function.") if mock is None: - raise ValueError( - "The 'mock' variable must be an argument of the test function." - ) + raise ValueError("The 'mock' variable must be an argument of the test function.") # Run test with a real robot. Skip test if robot connection fails. if not mock and not request.getfixturevalue("is_robot_available"): @@ -261,17 +241,11 @@ def require_camera(func): mock = kwargs.get("mock") if request is None: - raise ValueError( - "The 'request' fixture must be an argument of the test function." - ) + raise ValueError("The 'request' fixture must be an argument of the test function.") if camera_type is None: - raise ValueError( - "The 'camera_type' must be an argument of the test function." - ) + raise ValueError("The 'camera_type' must be an argument of the test function.") if mock is None: - raise ValueError( - "The 'mock' variable must be an argument of the test function." - ) + raise ValueError("The 'mock' variable must be an argument of the test function.") if not mock and not request.getfixturevalue("is_camera_available"): pytest.skip(f"A {camera_type} camera is not available.") @@ -290,17 +264,11 @@ def require_motor(func): mock = kwargs.get("mock") if request is None: - raise ValueError( - "The 'request' fixture must be an argument of the test function." - ) + raise ValueError("The 'request' fixture must be an argument of the test function.") if motor_type is None: - raise ValueError( - "The 'motor_type' must be an argument of the test function." - ) + raise ValueError("The 'motor_type' must be an argument of the test function.") if mock is None: - raise ValueError( - "The 'mock' variable must be an argument of the test function." - ) + raise ValueError("The 'mock' variable must be an argument of the test function.") if not mock and not request.getfixturevalue("is_motor_available"): pytest.skip(f"A {motor_type} motor is not available.") diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py index f57a9eb7..4764bebf 100644 --- a/tests/utils/test_logging_utils.py +++ b/tests/utils/test_logging_utils.py @@ -91,9 +91,7 @@ def test_metrics_tracker_step(mock_metrics): def test_metrics_tracker_getattr(mock_metrics): - tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics - ) + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) assert tracker.loss == mock_metrics["loss"] assert tracker.accuracy == mock_metrics["accuracy"] with pytest.raises(AttributeError): @@ -101,17 +99,13 @@ def test_metrics_tracker_getattr(mock_metrics): def test_metrics_tracker_setattr(mock_metrics): - tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics - ) + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) tracker.loss = 2.0 assert tracker.loss.val == 2.0 def test_metrics_tracker_str(mock_metrics): - tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics - ) + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) tracker.loss.update(3.456, 1) tracker.accuracy.update(0.876, 1) output = str(tracker) @@ -120,9 +114,7 @@ def test_metrics_tracker_str(mock_metrics): def test_metrics_tracker_to_dict(mock_metrics): - tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics - ) + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) tracker.loss.update(5, 2) metrics_dict = tracker.to_dict() assert isinstance(metrics_dict, dict) @@ -131,9 +123,7 @@ def test_metrics_tracker_to_dict(mock_metrics): def test_metrics_tracker_reset_averages(mock_metrics): - tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics - ) + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) tracker.loss.update(10, 3) tracker.accuracy.update(0.95, 5) tracker.reset_averages() diff --git a/tests/utils/test_random_utils.py b/tests/utils/test_random_utils.py index 01df7341..daf08a89 100644 --- a/tests/utils/test_random_utils.py +++ b/tests/utils/test_random_utils.py @@ -118,9 +118,5 @@ def test_seeded_context(fixed_seed): seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item()) assert seeded_val1 == seeded_val2 - assert all( - a != b for a, b in zip(val1, seeded_val1, strict=True) - ) # changed inside the context - assert all( - a != b for a, b in zip(val2, seeded_val2, strict=True) - ) # changed again after exiting + assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context + assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 1fc63eef..b78f6e49 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -91,9 +91,7 @@ def test_save_training_state(tmp_path, optimizer, scheduler): def test_save_load_training_state(tmp_path, optimizer, scheduler): save_training_state(tmp_path, 10, optimizer, scheduler) - loaded_step, loaded_optimizer, loaded_scheduler = load_training_state( - tmp_path, optimizer, scheduler - ) + loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler) assert loaded_step == 10 assert loaded_optimizer is optimizer assert loaded_scheduler is scheduler