diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 6c3af54e..ab784193 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -22,7 +22,10 @@ from pathlib import Path import torch -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.common.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, +) from lerobot.common.datasets.utils import dataset_to_policy_features from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy @@ -48,12 +51,18 @@ def main(): # - dataset stats: for normalization and denormalization of input/outputs dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht") features = dataset_to_policy_features(dataset_metadata.features) - output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} - input_features = {key: ft for key, ft in features.items() if key not in output_features} + output_features = { + key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION + } + input_features = { + key: ft for key, ft in features.items() if key not in output_features + } # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example, # we'll just use the defaults and so no arguments other than input/output features need to be passed. - cfg = DiffusionConfig(input_features=input_features, output_features=output_features) + cfg = DiffusionConfig( + input_features=input_features, output_features=output_features + ) # We can now instantiate our policy with this config and the dataset stats. policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats) @@ -63,8 +72,12 @@ def main(): # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames # which can differ for inputs, outputs and rewards (if there are some). delta_timestamps = { - "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], - "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices], + "observation.image": [ + i / dataset_metadata.fps for i in cfg.observation_delta_indices + ], + "observation.state": [ + i / dataset_metadata.fps for i in cfg.observation_delta_indices + ], "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices], } @@ -77,7 +90,24 @@ def main(): # Load the previous action (-0.1), the next action to be executed (0.0), # and 14 future actions with a 0.1 seconds spacing. All these actions will be # used to supervise the policy. - "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], + "action": [ + -0.1, + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + 1.1, + 1.2, + 1.3, + 1.4, + ], } # We can then instantiate the dataset with these delta_timestamps configuration. @@ -99,7 +129,10 @@ def main(): done = False while not done: for batch in dataloader: - batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + batch = { + k: (v.to(device) if isinstance(v, torch.Tensor) else v) + for k, v in batch.items() + } loss, _ = policy.forward(batch) loss.backward() optimizer.step() diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index 80c9f3a8..e7eb7fb4 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -54,7 +54,24 @@ def main(): # Load the previous action (-0.1), the next action to be executed (0.0), # and 14 future actions with a 0.1 seconds spacing. All these actions will be # used to calculate the loss. - "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], + "action": [ + -0.1, + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + 1.1, + 1.2, + 1.3, + 1.4, + ], } # Load the last 10% of episodes of the dataset as a validation set. @@ -73,7 +90,9 @@ def main(): train_dataset = LeRobotDataset( "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps ) - val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps) + val_dataset = LeRobotDataset( + "lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps + ) print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}") print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}") diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index 1149ec83..0202b0e2 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -19,7 +19,10 @@ from lerobot.common.datasets.utils import load_image_as_numpy def estimate_num_samples( - dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 + dataset_len: int, + min_num_samples: int = 100, + max_num_samples: int = 10_000, + power: float = 0.75, ) -> int: """Heuristic to estimate the number of samples based on dataset size. The power controls the sample growth relative to dataset size. @@ -43,14 +46,18 @@ def sample_indices(data_len: int) -> list[int]: return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist() -def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300): +def auto_downsample_height_width( + img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300 +): _, height, width = img.shape if max(width, height) < max_size_threshold: # no downsampling needed return img - downsample_factor = int(width / target_size) if width > height else int(height / target_size) + downsample_factor = ( + int(width / target_size) if width > height else int(height / target_size) + ) return img[:, ::downsample_factor, ::downsample_factor] @@ -72,7 +79,9 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: +def get_feature_stats( + array: np.ndarray, axis: tuple, keepdims: bool +) -> dict[str, np.ndarray]: return { "min": np.min(array, axis=axis, keepdims=keepdims), "max": np.max(array, axis=axis, keepdims=keepdims), @@ -82,7 +91,9 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st } -def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict: +def compute_episode_stats( + episode_data: dict[str, list[str] | np.ndarray], features: dict +) -> dict: ep_stats = {} for key, data in episode_data.items(): if features[key]["dtype"] == "string": @@ -96,12 +107,15 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu axes_to_reduce = 0 # compute stats over the first axis keepdims = data.ndim == 1 # keep as np.array - ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) + ep_stats[key] = get_feature_stats( + ep_ft_array, axis=axes_to_reduce, keepdims=keepdims + ) # finally, we normalize and remove batch dim for images if features[key]["dtype"] in ["image", "video"]: ep_stats[key] = { - k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() + k: v if k == "count" else np.squeeze(v / 255.0, axis=0) + for k, v in ep_stats[key].items() } return ep_stats @@ -116,14 +130,22 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]): f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead." ) if v.ndim == 0: - raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") + raise ValueError( + "Number of dimensions must be at least 1, and is 0 instead." + ) if k == "count" and v.shape != (1,): - raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") + raise ValueError( + f"Shape of 'count' must be (1), but is {v.shape} instead." + ) if "image" in fkey and k != "count" and v.shape != (3, 1, 1): - raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") + raise ValueError( + f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead." + ) -def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: +def aggregate_feature_stats( + stats_ft_list: list[dict[str, dict]], +) -> dict[str, dict[str, np.ndarray]]: """Aggregates stats for a single feature.""" means = np.stack([s["mean"] for s in stats_ft_list]) variances = np.stack([s["std"] ** 2 for s in stats_ft_list]) @@ -152,7 +174,9 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d } -def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: +def aggregate_stats( + stats_list: list[dict[str, dict]], +) -> dict[str, dict[str, np.ndarray]]: """Aggregate stats from multiple compute_stats outputs into a single set of stats. The final stats will have the union of all data keys from each of the stats dicts. diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 38c01b42..f7a12838 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -58,7 +58,9 @@ def resolve_delta_timestamps( if key == "action" and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] if key.startswith("observation.") and cfg.observation_delta_indices is not None: - delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] + delta_timestamps[key] = [ + i / ds_meta.fps for i in cfg.observation_delta_indices + ] if len(delta_timestamps) == 0: delta_timestamps = None @@ -79,7 +81,9 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas LeRobotDataset | MultiLeRobotDataset """ image_transforms = ( - ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None + ImageTransforms(cfg.dataset.image_transforms) + if cfg.dataset.image_transforms.enable + else None ) if isinstance(cfg.dataset.repo_id, str): @@ -113,6 +117,8 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas if cfg.dataset.use_imagenet_stats: for key in dataset.meta.camera_keys: for stats_type, stats in IMAGENET_STATS.items(): - dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + dataset.meta.stats[key][stats_type] = torch.tensor( + stats, dtype=torch.float32 + ) return dataset diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 2dd685eb..81e5bae9 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -38,10 +38,14 @@ def safe_stop_image_writer(func): return wrapper -def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: +def image_array_to_pil_image( + image_array: np.ndarray, range_check: bool = True +) -> PIL.Image.Image: # TODO(aliberts): handle 1 channel and 4 for depth images if image_array.ndim != 3: - raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.") + raise ValueError( + f"The array has {image_array.ndim} dimensions, but 3 is expected for an image." + ) if image_array.shape[0] == 3: # Transpose from pytorch convention (C, H, W) to (H, W, C) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 61fd6cc5..e4fa422d 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -108,7 +108,9 @@ class LeRobotDatasetMetadata: self.episodes = load_episodes(self.root) if self._version < packaging.version.parse("v2.1"): self.stats = load_stats(self.root) - self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) + self.episodes_stats = backward_compatible_episodes_stats( + self.stats, self.episodes + ) else: self.episodes_stats = load_episodes_stats(self.root) self.stats = aggregate_stats(list(self.episodes_stats.values())) @@ -238,7 +240,9 @@ class LeRobotDatasetMetadata: Given a task in natural language, add it to the dictionary of tasks. """ if task in self.task_to_task_index: - raise ValueError(f"The task '{task}' already exists and can't be added twice.") + raise ValueError( + f"The task '{task}' already exists and can't be added twice." + ) task_index = self.info["total_tasks"] self.task_to_task_index[task] = task_index @@ -281,7 +285,11 @@ class LeRobotDatasetMetadata: write_episode(episode_dict, self.root) self.episodes_stats[episode_index] = episode_stats - self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats + self.stats = ( + aggregate_stats([self.stats, episode_stats]) + if self.stats + else episode_stats + ) write_episode_stats(episode_index, episode_stats, self.root) def update_video_info(self) -> None: @@ -345,13 +353,17 @@ class LeRobotDatasetMetadata: # as this would break the dict flattening in the stats computation, which uses '/' as separator for key in features: if "/" in key: - raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.") + raise ValueError( + f"Feature names should not contain '/'. Found '/' in feature '{key}'." + ) features = {**features, **DEFAULT_FEATURES} obj.tasks, obj.task_to_task_index = {}, {} obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, fps, robot_type, features, use_videos + ) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) @@ -482,7 +494,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.video_backend = ( + video_backend if video_backend else get_safe_default_codec() + ) self.delta_indices = None # Unused attributes @@ -495,28 +509,39 @@ class LeRobotDataset(torch.utils.data.Dataset): self.meta = LeRobotDatasetMetadata( self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) - if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"): - episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] + if self.episodes is not None and self.meta._version >= packaging.version.parse( + "v2.1" + ): + episodes_stats = [ + self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes + ] self.stats = aggregate_stats(episodes_stats) # Load actual data try: if force_cache_sync: raise FileNotFoundError - assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths()) + assert all( + (self.root / fpath).is_file() + for fpath in self.get_episodes_file_paths() + ) self.hf_dataset = self.load_hf_dataset() except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) self.download_episodes(download_videos) self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) + self.episode_data_index = get_episode_data_index( + self.meta.episodes, self.episodes + ) # Check timestamps timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy() episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy() ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} - check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) + check_timestamps_sync( + timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s + ) # Setup delta_indices if self.delta_timestamps is not None: @@ -568,7 +593,9 @@ class LeRobotDataset(torch.utils.data.Dataset): else: hub_api.upload_folder(**upload_kwargs) - if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): + if not hub_api.file_exists( + self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch + ): card = create_lerobot_dataset_card( tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs ) @@ -576,8 +603,12 @@ class LeRobotDataset(torch.utils.data.Dataset): if tag_version: with contextlib.suppress(RevisionNotFoundError): - hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") - hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + hub_api.delete_tag( + self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset" + ) + hub_api.create_tag( + self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset" + ) def pull_from_repo( self, @@ -609,7 +640,11 @@ class LeRobotDataset(torch.utils.data.Dataset): self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) def get_episodes_file_paths(self) -> list[Path]: - episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) + episodes = ( + self.episodes + if self.episodes is not None + else list(range(self.meta.total_episodes)) + ) fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] if len(self.meta.video_keys) > 0: video_files = [ @@ -640,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset): def create_hf_dataset(self) -> datasets.Dataset: features = get_hf_features_from_features(self.features) ft_dict = {col: [] for col in features} - hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") + hf_dataset = datasets.Dataset.from_dict( + ft_dict, features=features, split="train" + ) # TODO(aliberts): hf_dataset.set_format("torch") hf_dataset.set_transform(hf_transform_to_torch) @@ -726,7 +763,9 @@ class LeRobotDataset(torch.utils.data.Dataset): if key not in self.meta.video_keys } - def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: + def _query_videos( + self, query_timestamps: dict[str, list[float]], ep_idx: int + ) -> dict[str, torch.Tensor]: """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. This probably happens because a memory reference to the video loader is created in @@ -735,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset): item = {} for vid_key, query_ts in query_timestamps.items(): video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend) + frames = decode_video_frames( + video_path, query_ts, self.tolerance_s, self.video_backend + ) item[vid_key] = frames.squeeze(0) return item @@ -789,7 +830,9 @@ class LeRobotDataset(torch.utils.data.Dataset): ) def create_episode_buffer(self, episode_index: int | None = None) -> dict: - current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index + current_ep_idx = ( + self.meta.total_episodes if episode_index is None else episode_index + ) ep_buffer = {} # size and task are special cases that are not in self.features ep_buffer["size"] = 0 @@ -887,7 +930,9 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_tasks = list(set(tasks)) episode_index = episode_buffer["episode_index"] - episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) + episode_buffer["index"] = np.arange( + self.meta.total_frames, self.meta.total_frames + episode_length + ) episode_buffer["episode_index"] = np.full((episode_length,), episode_index) # Add new tasks to the tasks dictionary @@ -897,12 +942,17 @@ class LeRobotDataset(torch.utils.data.Dataset): self.meta.add_task(task) # Given tasks in natural language, find their corresponding task indices - episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) + episode_buffer["task_index"] = np.array( + [self.meta.get_task_index(task) for task in tasks] + ) for key, ft in self.features.items(): # index, episode_index, task_index are already processed above, and image and video # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [ + "image", + "video", + ]: continue episode_buffer[key] = np.stack(episode_buffer[key]) @@ -944,7 +994,9 @@ class LeRobotDataset(torch.utils.data.Dataset): def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: episode_dict = {key: episode_buffer[key] for key in self.hf_features} - ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") + ep_dataset = datasets.Dataset.from_dict( + episode_dict, features=self.hf_features, split="train" + ) ep_dataset = embed_images(ep_dataset) self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset]) self.hf_dataset.set_transform(hf_transform_to_torch) @@ -1063,7 +1115,9 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.delta_indices = None obj.episode_data_index = None - obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj.video_backend = ( + video_backend if video_backend is not None else get_safe_default_codec() + ) return obj @@ -1088,7 +1142,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): super().__init__() self.repo_ids = repo_ids self.root = Path(root) if root else HF_LEROBOT_HOME - self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids} + self.tolerances_s = ( + tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids} + ) # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which # are handled by this class. self._datasets = [ diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 401120da..2dbd5904 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -141,12 +141,16 @@ class SharpnessJitter(Transform): return float(sharpness[0]), float(sharpness[1]) def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: - sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item() + sharpness_factor = ( + torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item() + ) return {"sharpness_factor": sharpness_factor} def transform(self, inpt: Any, params: dict[str, Any]) -> Any: sharpness_factor = params["sharpness_factor"] - return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) + return self._call_kernel( + F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor + ) @dataclass diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 1050f6eb..8171d000 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -135,7 +135,9 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: elif isinstance(value, (int, float)): serialized_dict[key] = value else: - raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") + raise NotImplementedError( + f"The value '{value}' of type '{type(value)}' is not supported." + ) return unflatten_dict(serialized_dict) @@ -214,7 +216,10 @@ def write_task(task_index: int, task: dict, local_dir: Path): def load_tasks(local_dir: Path) -> tuple[dict, dict]: tasks = load_jsonlines(local_dir / TASKS_PATH) - tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + tasks = { + item["task_index"]: item["task"] + for item in sorted(tasks, key=lambda x: x["task_index"]) + } task_to_task_index = {task: task_index for task_index, task in tasks.items()} return tasks, task_to_task_index @@ -225,13 +230,19 @@ def write_episode(episode: dict, local_dir: Path): def load_episodes(local_dir: Path) -> dict: episodes = load_jsonlines(local_dir / EPISODES_PATH) - return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} + return { + item["episode_index"]: item + for item in sorted(episodes, key=lambda x: x["episode_index"]) + } def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]` # is a dictionary of stats and not an integer. - episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)} + episode_stats = { + "episode_index": episode_index, + "stats": serialize_dict(episode_stats), + } append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH) @@ -275,7 +286,9 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): elif first_item is None: pass else: - items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] + items_dict[key] = [ + x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key] + ] return items_dict @@ -328,7 +341,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> Otherwise, will throw a `CompatibilityError`. """ target_version = ( - packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version + packaging.version.parse(version) + if not isinstance(version, packaging.version.Version) + else version ) hub_versions = get_repo_versions(repo_id) @@ -349,12 +364,16 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> return f"v{target_version}" compatibles = [ - v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor + v + for v in hub_versions + if v.major == target_version.major and v.minor <= target_version.minor ] if compatibles: return_version = max(compatibles) if return_version < target_version: - logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}") + logging.warning( + f"Revision {version} for {repo_id} not found, using version v{return_version}" + ) return f"v{return_version}" lower_major = [v for v in hub_versions if v.major < target_version.major] @@ -461,7 +480,9 @@ def create_empty_dataset_info( def get_episode_data_index( episode_dicts: dict[dict], episodes: list[int] | None = None ) -> dict[str, torch.Tensor]: - episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} + episode_lengths = { + ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items() + } if episodes is not None: episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} @@ -511,7 +532,9 @@ def check_timestamps_sync( # Mask to ignore differences at the boundaries between episodes mask = np.ones(len(diffs), dtype=bool) - ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode + ignored_diffs = ( + episode_data_index["to"][:-1] - 1 + ) # indices at the end of each episode mask[ignored_diffs] = False filtered_within_tolerance = within_tolerance[mask] @@ -720,14 +743,18 @@ def validate_frame(frame: dict, features: dict): expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} actual_features = set(frame.keys()) - error_message = validate_features_presence(actual_features, expected_features, optional_features) + error_message = validate_features_presence( + actual_features, expected_features, optional_features + ) if "task" in frame: error_message += validate_feature_string("task", frame["task"]) common_features = actual_features & (expected_features | optional_features) for name in common_features - {"task"}: - error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) + error_message += validate_feature_dtype_and_shape( + name, features[name], frame[name] + ) if error_message: raise ValueError(error_message) @@ -750,7 +777,9 @@ def validate_features_presence( return error_message -def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): +def validate_feature_dtype_and_shape( + name: str, feature: dict, value: np.ndarray | PILImage.Image | str +): expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): @@ -760,7 +789,9 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray elif expected_dtype == "string": return validate_feature_string(name, value) else: - raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + raise NotImplementedError( + f"The feature dtype '{expected_dtype}' is not implemented yet." + ) def validate_feature_numpy_array( @@ -782,13 +813,17 @@ def validate_feature_numpy_array( return error_message -def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): +def validate_feature_image_or_video( + name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image +): # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): actual_shape = value.shape c, h, w = expected_shape - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + if len(actual_shape) != 3 or ( + actual_shape != (c, h, w) and actual_shape != (h, w, c) + ): error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" elif isinstance(value, PILImage.Image): pass @@ -819,7 +854,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: ) if episode_buffer["size"] == 0: - raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") + raise ValueError( + "You must add one or several frames with `add_frame` before calling `add_episode`." + ) buffer_keys = set(episode_buffer.keys()) - {"task", "size"} if not buffer_keys == set(features): diff --git a/lerobot/common/datasets/v21/_remove_language_instruction.py b/lerobot/common/datasets/v21/_remove_language_instruction.py index 643ddd3f..66461c59 100644 --- a/lerobot/common/datasets/v21/_remove_language_instruction.py +++ b/lerobot/common/datasets/v21/_remove_language_instruction.py @@ -35,22 +35,30 @@ def fix_dataset(repo_id: str) -> str: dataset_info = get_dataset_config_info(repo_id, "default") with SuppressWarnings(): - lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True) + lerobot_metadata = LeRobotDatasetMetadata( + repo_id, revision=V20, force_cache_sync=True + ) - meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"} + meta_features = { + key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video" + } parquet_features = set(dataset_info.features) diff_parquet_meta = parquet_features - meta_features diff_meta_parquet = meta_features - parquet_features if diff_parquet_meta: - raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}") + raise ValueError( + f"In parquet not in info.json: {parquet_features - meta_features}" + ) if not diff_meta_parquet: return f"{repo_id}: skipped (no diff)" if diff_meta_parquet: - logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}") + logging.warning( + f"In info.json not in parquet: {meta_features - parquet_features}" + ) assert diff_meta_parquet == {"language_instruction"} lerobot_metadata.features.pop("language_instruction") write_info(lerobot_metadata.info, lerobot_metadata.root) diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index 176d16d0..3cd6e52b 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -37,8 +37,16 @@ import logging from huggingface_hub import HfApi from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info -from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats +from lerobot.common.datasets.utils import ( + EPISODES_STATS_PATH, + STATS_PATH, + load_stats, + write_info, +) +from lerobot.common.datasets.v21.convert_stats import ( + check_aggregate_stats, + convert_stats, +) V20 = "v2.0" V21 = "v2.1" @@ -79,13 +87,21 @@ def convert_dataset( hub_api = HfApi() if hub_api.file_exists( - repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" + repo_id=dataset.repo_id, + filename=STATS_PATH, + revision=branch, + repo_type="dataset", ): hub_api.delete_file( - path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" + path_in_repo=STATS_PATH, + repo_id=dataset.repo_id, + revision=branch, + repo_type="dataset", ) - hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + hub_api.create_tag( + repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset" + ) if __name__ == "__main__": diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py index 4a20b427..a0df5337 100644 --- a/lerobot/common/datasets/v21/convert_stats.py +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -17,12 +17,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np from tqdm import tqdm -from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices +from lerobot.common.datasets.compute_stats import ( + aggregate_stats, + get_feature_stats, + sample_indices, +) from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import write_episode_stats -def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: +def sample_episode_video_frames( + dataset: LeRobotDataset, episode_index: int, ft_key: str +) -> np.ndarray: ep_len = dataset.meta.episodes[episode_index]["length"] sampled_indices = sample_indices(ep_len) query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices}) @@ -45,11 +51,14 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0 keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1 - ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims) + ep_stats[key] = get_feature_stats( + ep_ft_data, axis=axes_to_reduce, keepdims=keepdims + ) if ft["dtype"] in ["image", "video"]: # remove batch dim ep_stats[key] = { - k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() + k: v if k == "count" else np.squeeze(v, axis=0) + for k, v in ep_stats[key].items() } dataset.meta.episodes_stats[ep_idx] = ep_stats @@ -95,5 +104,9 @@ def check_aggregate_stats( if key in reference_stats and stat in reference_stats[key]: err_msg = f"feature='{key}' stats='{stat}'" np.testing.assert_allclose( - val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg + val, + reference_stats[key][stat], + rtol=rtol, + atol=atol, + err_msg=err_msg, ) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 3aaba567..d4f345d3 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -65,7 +65,9 @@ def decode_video_frames( if backend == "torchcodec": return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) elif backend in ["pyav", "video_reader"]: - return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + return decode_video_frames_torchvision( + video_path, timestamps, tolerance_s, backend + ) else: raise ValueError(f"Unsupported video backend: {backend}") diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index cf90048a..69c31684 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -61,10 +61,16 @@ class AlohaEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": - self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(480, 640, 3) + ) elif self.obs_type == "pixels_agent_pos": - self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) - self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["agent_pos"] = PolicyFeature( + type=FeatureType.STATE, shape=(14,) + ) + self.features["pixels/top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(480, 640, 3) + ) @property def gym_kwargs(self) -> dict: @@ -102,9 +108,13 @@ class PushtEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) + self.features["pixels"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(384, 384, 3) + ) elif self.obs_type == "environment_state_agent_pos": - self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) + self.features["environment_state"] = PolicyFeature( + type=FeatureType.ENV, shape=(16,) + ) @property def gym_kwargs(self) -> dict: @@ -143,7 +153,9 @@ class XarmEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) + self.features["agent_pos"] = PolicyFeature( + type=FeatureType.STATE, shape=(4,) + ) @property def gym_kwargs(self) -> dict: diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index b0efa71e..9670e456 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -32,7 +32,9 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: raise ValueError(f"Policy type '{env_type}' is not available.") -def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None: +def make_env( + cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False +) -> gym.vector.VectorEnv | None: """Makes a gym vector environment according to the config. Args: @@ -56,7 +58,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g try: importlib.import_module(package_name) except ModuleNotFoundError as e: - print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`") + print( + f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`" + ) raise e gym_handle = f"{package_name}/{cfg.task}" @@ -64,7 +68,10 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g # batched version of the env that returns an observation of shape (b, c) env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv env = env_cls( - [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] + [ + lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) + for _ in range(n_envs) + ] ) return env diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index b44e041b..77716b75 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -46,7 +46,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # sanity check that images are channel last _, h, w, c = img.shape - assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" + assert c < h and c < w, ( + f"expect channel last images, but instead got {img.shape=}" + ) # sanity check that images are uint8 assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" @@ -79,7 +81,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: for key, ft in env_cfg.features.items(): if ft.type is FeatureType.VISUAL: if len(ft.shape) != 3: - raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})") + raise ValueError( + f"Number of dimensions of {key} != 3 (shape={ft.shape})" + ) shape = get_channel_first_image_shape(ft.shape) feature = PolicyFeature(type=ft.type, shape=shape) @@ -92,7 +96,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: return policy_features -def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: +def preprocess_maniskill_observation( + observations: dict[str, np.ndarray], +) -> dict[str, Tensor]: """Convert environment observation to LeRobot format observation. Args: observation: Dictionary of observation batches from a Gym vector environment. diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index b140270b..84038a41 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -250,9 +250,9 @@ class Logger: ) # For the case where the optimizer is a dictionary of optimizers (e.g., sac) if type(training_state["optimizer"]) is dict: - assert set(training_state["optimizer"].keys()) == set( - optimizer.keys() - ), "Optimizer dictionaries do not have the same keys during resume!" + assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), ( + "Optimizer dictionaries do not have the same keys during resume!" + ) for k, v in training_state["optimizer"].items(): optimizer[k].load_state_dict(v) else: diff --git a/lerobot/common/optim/factory.py b/lerobot/common/optim/factory.py index 10ff3df7..0854332b 100644 --- a/lerobot/common/optim/factory.py +++ b/lerobot/common/optim/factory.py @@ -34,7 +34,13 @@ def make_optimizer_and_scheduler( Returns: tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`. """ - params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters() + params = ( + policy.get_optim_params() + if cfg.use_policy_training_preset + else policy.parameters() + ) optimizer = cfg.optimizer.build(params) - lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None + lr_scheduler = ( + cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None + ) return optimizer, lr_scheduler diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py index 0cf4124c..bb82cf1d 100644 --- a/lerobot/common/optim/optimizers.py +++ b/lerobot/common/optim/optimizers.py @@ -102,7 +102,9 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS) -def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: +def load_optimizer_state( + optimizer: torch.optim.Optimizer, save_dir: Path +) -> torch.optim.Optimizer: current_state_dict = optimizer.state_dict() flat_state = load_file(save_dir / OPTIMIZER_STATE) state = unflatten_dict(flat_state) diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index 7e158394..f189a647 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -36,7 +36,9 @@ class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC): return self.get_choice_name(self.__class__) @abc.abstractmethod - def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None: + def build( + self, optimizer: Optimizer, num_training_steps: int + ) -> LRScheduler | None: raise NotImplementedError @@ -49,7 +51,11 @@ class DiffuserSchedulerConfig(LRSchedulerConfig): def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: from diffusers.optimization import get_scheduler - kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer} + kwargs = { + **asdict(self), + "num_training_steps": num_training_steps, + "optimizer": optimizer, + } return get_scheduler(**kwargs) @@ -71,7 +77,14 @@ class VQBeTSchedulerConfig(LRSchedulerConfig): progress = float(adjusted_step - self.num_warmup_steps) / float( max(1, num_training_steps - self.num_warmup_steps) ) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress))) + return max( + 0.0, + 0.5 + * ( + 1.0 + + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress) + ), + ) return LambdaLR(optimizer, lr_lambda, -1) @@ -98,7 +111,9 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): def cosine_decay_schedule(current_step): step = min(current_step, self.num_decay_steps) - cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps)) + cosine_decay = 0.5 * ( + 1 + math.cos(math.pi * step / self.num_decay_steps) + ) alpha = self.decay_lr / self.peak_lr decayed = (1 - alpha) * cosine_decay + alpha return decayed @@ -117,6 +132,8 @@ def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler: - state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict()) + state_dict = deserialize_json_into_object( + save_dir / SCHEDULER_STATE, scheduler.state_dict() + ) scheduler.load_state_dict(state_dict) return scheduler diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 7a5819b7..3c5d30b7 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -171,7 +171,9 @@ class ACTConfig(PreTrainedConfig): def validate_features(self) -> None: if not self.image_features and not self.env_state_feature: - raise ValueError("You must provide at least one image or the environment state among the inputs.") + raise ValueError( + "You must provide at least one image or the environment state among the inputs." + ) @property def observation_delta_indices(self) -> None: diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 3dec1584..4ab87107 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_inputs = Normalize( + config.input_features, config.normalization_mapping, dataset_stats + ) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [batch[key] for key in self.config.image_features] + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [ + batch[key] for key in self.config.image_features + ] # If we are doing temporal ensembling, do online updates where we keep track of the number of actions # we are ensembling over. @@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy): """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [batch[key] for key in self.config.image_features] + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [ + batch[key] for key in self.config.image_features + ] batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) @@ -406,14 +416,18 @@ class ACT(nn.Module): n_1d_tokens += 1 self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) if self.config.image_features: - self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d( + config.dim_model // 2 + ) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) + self.action_head = nn.Linear( + config.dim_model, self.config.action_feature.shape[0] + ) self._reset_parameters() @@ -461,14 +475,20 @@ class ACT(nn.Module): self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) if self.config.robot_state_feature: - robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = self.vae_encoder_robot_state_input_proj( + batch["observation.state"] + ) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) action_embed = self.vae_encoder_action_input_proj( batch["action"] ) # (B, S, D) if self.config.robot_state_feature: - vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) + vae_encoder_input = [ + cls_embed, + robot_state_embed, + action_embed, + ] # (B, S+2, D) else: vae_encoder_input = [cls_embed, action_embed] vae_encoder_input = torch.cat(vae_encoder_input, axis=1) @@ -517,7 +537,9 @@ class ACT(nn.Module): ) # Robot state token. if self.config.robot_state_feature: - encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) + encoder_in_tokens.append( + self.encoder_robot_state_input_proj(batch["observation.state"]) + ) # Environment state token. if self.config.env_state_feature: encoder_in_tokens.append( @@ -534,7 +556,9 @@ class ACT(nn.Module): # For a list of images, the H and W may vary but H*W is constant. for img in batch["observation.images"]: cam_features = self.backbone(img)["feature_map"] - cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to( + dtype=cam_features.dtype + ) cam_features = self.encoder_img_feat_input_proj(cam_features) # Rearrange features to (sequence, batch, dim). diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index e73c65fe..32889ff0 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -205,11 +205,16 @@ class DiffusionConfig(PreTrainedConfig): def validate_features(self) -> None: if len(self.image_features) == 0 and self.env_state_feature is None: - raise ValueError("You must provide at least one image or the environment state among the inputs.") + raise ValueError( + "You must provide at least one image or the environment state among the inputs." + ) if self.crop_shape is not None: for key, image_ft in self.image_features.items(): - if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: + if ( + self.crop_shape[0] > image_ft.shape[1] + or self.crop_shape[1] > image_ft.shape[2] + ): raise ValueError( f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " f"for `crop_shape` and {image_ft.shape} for " diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index d331dddf..cb84292b 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -70,7 +70,9 @@ class DiffusionPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_inputs = Normalize( + config.input_features, config.normalization_mapping, dataset_stats + ) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -97,7 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy): if self.config.image_features: self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + self._queues["observation.environment_state"] = deque( + maxlen=self.config.n_obs_steps + ) @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -123,7 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy): """ batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( [batch[key] for key in self.config.image_features], dim=-4 ) @@ -151,7 +157,9 @@ class DiffusionPolicy(PreTrainedPolicy): """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack( [batch[key] for key in self.config.image_features], dim=-4 ) @@ -515,11 +523,15 @@ class DiffusionRgbEncoder(nn.Module): # Note: we have a check in the config class to make sure all images have the same shape. images_shape = next(iter(config.image_features.values())).shape - dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] + dummy_shape_h_w = ( + config.crop_shape if config.crop_shape is not None else images_shape[1:] + ) dummy_shape = (1, images_shape[0], *dummy_shape_h_w) feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] - self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.pool = SpatialSoftmax( + feature_map_shape, num_kp=config.spatial_softmax_num_keypoints + ) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() @@ -719,7 +731,9 @@ class DiffusionConditionalUnet1d(nn.Module): ) self.final_conv = nn.Sequential( - DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size), + DiffusionConv1dBlock( + config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size + ), nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1), ) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index fed799d9..9423d774 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -104,7 +104,9 @@ def make_policy( PreTrainedPolicy: _description_ """ if bool(ds_meta) == bool(env_cfg): - raise ValueError("Either one of a dataset metadata or a sim env must be provided.") + raise ValueError( + "Either one of a dataset metadata or a sim env must be provided." + ) # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error. # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies? @@ -134,8 +136,12 @@ def make_policy( ) features = env_to_policy_features(env_cfg) - cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} - cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} + cfg.output_features = { + key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION + } + cfg.input_features = { + key: ft for key, ft in features.items() if key not in cfg.output_features + } kwargs["config"] = cfg if cfg.pretrained_path: diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 012c854d..46428700 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -82,25 +82,43 @@ def create_stats_buffers( if stats: if isinstance(stats[key]["mean"], np.ndarray): if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) - buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) + buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to( + dtype=torch.float32 + ) + buffer["std"].data = torch.from_numpy(stats[key]["std"]).to( + dtype=torch.float32 + ) elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) - buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) + buffer["min"].data = torch.from_numpy(stats[key]["min"]).to( + dtype=torch.float32 + ) + buffer["max"].data = torch.from_numpy(stats[key]["max"]).to( + dtype=torch.float32 + ) elif isinstance(stats[key]["mean"], torch.Tensor): # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated # tensors anywhere (for example, when we use the same stats for normalization and # unnormalization). See the logic here # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) - buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) + buffer["mean"].data = ( + stats[key]["mean"].clone().to(dtype=torch.float32) + ) + buffer["std"].data = ( + stats[key]["std"].clone().to(dtype=torch.float32) + ) elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) - buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) + buffer["min"].data = ( + stats[key]["min"].clone().to(dtype=torch.float32) + ) + buffer["max"].data = ( + stats[key]["max"].clone().to(dtype=torch.float32) + ) else: type_ = type(stats[key]["mean"]) - raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") + raise ValueError( + f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead." + ) stats_buffers[key] = buffer return stats_buffers diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py index 6bd7c91f..dc729eea 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py @@ -44,7 +44,9 @@ def main(): else: dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human" - ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" + ckpt_torch_dir = ( + Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" + ) ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}" save_dir = Path(f"../openpi/data/{model_name}/save") @@ -70,7 +72,9 @@ def main(): # Create LeRobot batch from Jax batch = {} for cam_key, uint_chw_array in example["images"].items(): - batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 + batch[f"observation.images.{cam_key}"] = ( + torch.from_numpy(uint_chw_array) / 255.0 + ) batch["observation.state"] = torch.from_numpy(example["state"]) batch["action"] = torch.from_numpy(outputs["actions"]) batch["task"] = example["prompt"] diff --git a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py index 8835da31..6b8406a9 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py +++ b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py @@ -54,7 +54,9 @@ def get_paligemma_config(precision: str): "projector_hidden_act": "gelu_fast", "vision_use_head": False, } - final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + final_config = PaliGemmaConfig( + text_config=text_config, vision_config=vision_config, **config + ) return final_config diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py index 73ff506f..9fa00fbf 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py +++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py @@ -61,7 +61,11 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import ( ) from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy -PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16} +PRECISIONS = { + "bfloat16": torch.bfloat16, + "float32": torch.float32, + "float16": torch.float16, +} def slice_paligemma_state_dict(state_dict, config): @@ -318,7 +322,9 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict: return {f"{prefix}{key}": value for key, value in d.items()} -def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str): +def convert_pi0_checkpoint( + checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str +): # Break down orbax ckpts - they are in OCDBT initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) # process projection params @@ -378,7 +384,9 @@ def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: st # gemma_config=gemma_config, paligemma_config=paligemma_config) pi0_model = PI0Policy(pi0_config) - paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.") + paligemma_params = update_keys_with_prefix( + paligemma_params, "model.paligemma_with_expert." + ) gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.") projection_params = update_keys_with_prefix(projection_params, "model.") diff --git a/lerobot/common/policies/pi0/flex_attention.py b/lerobot/common/policies/pi0/flex_attention.py index 35628cdd..0bcb15b9 100644 --- a/lerobot/common/policies/pi0/flex_attention.py +++ b/lerobot/common/policies/pi0/flex_attention.py @@ -48,18 +48,32 @@ def flex_attention_forward( key_states = key_states[:, :, :, None, :] key_states = key_states.expand( - batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim + batch_size, + key_states.shape[1], + num_key_value_heads, + num_key_value_groups, + head_dim, ) key_states = key_states.reshape( - batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim + batch_size, + key_states.shape[1], + num_key_value_heads * num_key_value_groups, + head_dim, ) value_states = value_states[:, :, :, None, :] value_states = value_states.expand( - batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim + batch_size, + value_states.shape[1], + num_key_value_heads, + num_key_value_groups, + head_dim, ) value_states = value_states.reshape( - batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim + batch_size, + value_states.shape[1], + num_key_value_heads * num_key_value_groups, + head_dim, ) query_states = query_states.transpose(1, 2) diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index bc53bf85..1e068c24 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -69,7 +69,11 @@ from lerobot.common.utils.utils import get_safe_dtype def create_sinusoidal_pos_embedding( - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" + time: torch.tensor, + dimension: int, + min_period: float, + max_period: float, + device="cpu", ) -> Tensor: """Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: @@ -189,7 +193,9 @@ def aloha_gripper_to_angular(value): # This is the inverse of the angular to linear transformation inside the Interbotix code. def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + value = (horn_radius**2 + linear_position**2 - arm_length**2) / ( + 2 * horn_radius * linear_position + ) return safe_arcsin(value) # The constants are taken from the Interbotix code. @@ -240,7 +246,9 @@ class PI0Policy(PreTrainedPolicy): super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_inputs = Normalize( + config.input_features, config.normalization_mapping, dataset_stats + ) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -248,7 +256,9 @@ class PI0Policy(PreTrainedPolicy): config.output_features, config.normalization_mapping, dataset_stats ) - self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + self.language_tokenizer = AutoTokenizer.from_pretrained( + "google/paligemma-3b-pt-224" + ) self.model = PI0FlowMatching(config) self.reset() @@ -261,7 +271,9 @@ class PI0Policy(PreTrainedPolicy): return self.parameters() @torch.no_grad - def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + def select_action( + self, batch: dict[str, Tensor], noise: Tensor | None = None + ) -> Tensor: """Select a single action given environment observations. This method wraps `select_actions` in order to return one action at a time for execution in the @@ -300,7 +312,9 @@ class PI0Policy(PreTrainedPolicy): self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: + def forward( + self, batch: dict[str, Tensor], noise=None, time=None + ) -> tuple[Tensor, dict[str, Tensor]]: """Do a full training forward pass to compute the loss""" if self.config.adapt_to_pi_aloha: batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) @@ -316,7 +330,9 @@ class PI0Policy(PreTrainedPolicy): actions_is_pad = batch.get("actions_is_pad") loss_dict = {} - losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) + losses = self.model.forward( + images, img_masks, lang_tokens, lang_masks, state, actions, noise, time + ) loss_dict["losses_after_forward"] = losses.clone() if actions_is_pad is not None: @@ -343,7 +359,9 @@ class PI0Policy(PreTrainedPolicy): img_masks = [] present_img_keys = [key for key in self.config.image_features if key in batch] - missing_img_keys = [key for key in self.config.image_features if key not in batch] + missing_img_keys = [ + key for key in self.config.image_features if key not in batch + ] if len(present_img_keys) == 0: raise ValueError( @@ -355,7 +373,9 @@ class PI0Policy(PreTrainedPolicy): img = batch[key] if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + img = resize_with_pad( + img, *self.config.resize_imgs_with_padding, pad_value=0 + ) # Normalize from range [0,1] to [-1,1] as expacted by siglip img = img * 2.0 - 1.0 @@ -394,7 +414,9 @@ class PI0Policy(PreTrainedPolicy): return_tensors="pt", ) lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + lang_masks = tokenized_prompt["attention_mask"].to( + device=device, dtype=torch.bool + ) return lang_tokens, lang_masks @@ -413,7 +435,9 @@ class PI0Policy(PreTrainedPolicy): actions[:, :, motor_idx] *= -1 # Reverse the gripper transformation that is being applied by the Aloha runtime. for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + actions[:, :, motor_idx] = aloha_gripper_from_angular( + actions[:, :, motor_idx] + ) return actions def _pi_aloha_encode_actions_inv(self, actions): @@ -422,7 +446,9 @@ class PI0Policy(PreTrainedPolicy): actions[:, :, motor_idx] *= -1 # Reverse the gripper transformation that is being applied by the Aloha runtime. for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv( + actions[:, :, motor_idx] + ) return actions def prepare_state(self, batch): @@ -472,15 +498,25 @@ class PI0FlowMatching(nn.Module): train_expert_only=self.config.train_expert_only, attention_implementation=self.config.attention_implementation, ) - self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config) + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_with_export_config + ) # Projections are float32 self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) - self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width) - self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim) + self.action_in_proj = nn.Linear( + self.config.max_action_dim, self.config.proj_width + ) + self.action_out_proj = nn.Linear( + self.config.proj_width, self.config.max_action_dim + ) - self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width) - self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) + self.action_time_mlp_in = nn.Linear( + self.config.proj_width * 2, self.config.proj_width + ) + self.action_time_mlp_out = nn.Linear( + self.config.proj_width, self.config.proj_width + ) self.set_requires_grad() @@ -524,7 +560,9 @@ class PI0FlowMatching(nn.Module): # Normalize image embeddings img_emb_dim = img_emb.shape[-1] - img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) + img_emb = img_emb * torch.tensor( + img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device + ) bsize, num_img_embs = img_emb.shape[:2] img_mask = img_mask[:, None].expand(bsize, num_img_embs) @@ -577,7 +615,11 @@ class PI0FlowMatching(nn.Module): # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] time_emb = create_sinusoidal_pos_embedding( - timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device + timestep, + self.config.proj_width, + min_period=4e-3, + max_period=4.0, + device=device, ) time_emb = time_emb.type(dtype=dtype) @@ -595,7 +637,9 @@ class PI0FlowMatching(nn.Module): embs.append(action_time_emb) bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) + action_time_mask = torch.ones( + bsize, action_time_dim, dtype=torch.bool, device=device + ) pad_masks.append(action_time_mask) # Set attention masks so that image, language and state inputs do not attend to action tokens @@ -609,7 +653,15 @@ class PI0FlowMatching(nn.Module): return embs, pad_masks, att_masks def forward( - self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None + self, + images, + img_masks, + lang_tokens, + lang_masks, + state, + actions, + noise=None, + time=None, ) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" if noise is None: @@ -625,7 +677,9 @@ class PI0FlowMatching(nn.Module): prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks ) - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( + state, x_t, time + ) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -649,13 +703,19 @@ class PI0FlowMatching(nn.Module): losses = F.mse_loss(u_t, v_t, reduction="none") return losses - def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: + def sample_actions( + self, images, img_masks, lang_tokens, lang_masks, state, noise=None + ) -> Tensor: """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" bsize = state.shape[0] device = state.device if noise is None: - actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim) + actions_shape = ( + bsize, + self.config.n_action_steps, + self.config.max_action_dim, + ) noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( @@ -703,12 +763,16 @@ class PI0FlowMatching(nn.Module): timestep, ): """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( + state, x_t, timestep + ) suffix_len = suffix_pad_masks.shape[1] batch_size = prefix_pad_masks.shape[0] prefix_len = prefix_pad_masks.shape[1] - prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand( + batch_size, suffix_len, prefix_len + ) suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py index 76e2ce60..fc0ae065 100644 --- a/lerobot/common/policies/pi0/paligemma_with_expert.py +++ b/lerobot/common/policies/pi0/paligemma_with_expert.py @@ -39,9 +39,13 @@ def apply_rope(x, positions, max_wavelength=10_000): dtype = x.dtype x = x.to(torch.float32) - freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + freq_exponents = (2.0 / x.shape[-1]) * torch.arange( + d_half, dtype=torch.float32, device=device + ) timescale = max_wavelength**freq_exponents - radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to( + torch.float32 + ) radians = radians[..., None, :] @@ -174,7 +178,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel): def __init__(self, config: PaliGemmaWithExpertConfig): super().__init__(config=config) self.config = config - self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config) + self.paligemma = PaliGemmaForConditionalGeneration( + config=config.paligemma_config + ) self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) # Remove unused embed_tokens self.gemma_expert.model.embed_tokens = None @@ -291,14 +297,22 @@ class PaliGemmaWithExpertModel(PreTrainedModel): # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach # the max len, then we (for instance) double the cache size. This implementation already exists # in `transformers`. (molbap) - key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) + key_states = torch.cat( + [past_key_values[layer_idx]["key_states"], key_states], dim=1 + ) value_states = torch.cat( - [past_key_values[layer_idx]["value_states"], value_states], dim=1 + [past_key_values[layer_idx]["value_states"], value_states], + dim=1, ) attention_interface = self.get_attention_interface() att_output = attention_interface( - attention_mask, batch_size, head_dim, query_states, key_states, value_states + attention_mask, + batch_size, + head_dim, + query_states, + key_states, + value_states, ) att_output = att_output.to(dtype=torch.bfloat16) @@ -358,15 +372,29 @@ class PaliGemmaWithExpertModel(PreTrainedModel): return attention_interface def flash_attention_forward( - self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + self, + attention_mask, + batch_size, + head_dim, + query_states, + key_states, + value_states, ): raise NotImplementedError("FA2 is not implemented (yet)") def eager_attention_forward( - self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + self, + attention_mask, + batch_size, + head_dim, + query_states, + key_states, + value_states, ): num_att_heads = self.config.paligemma_config.text_config.num_attention_heads - num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads + num_key_value_heads = ( + self.config.paligemma_config.text_config.num_key_value_heads + ) num_key_value_groups = num_att_heads // num_key_value_heads # query_states: batch_size, sequence_length, num_att_head, head_dim @@ -375,17 +403,31 @@ class PaliGemmaWithExpertModel(PreTrainedModel): sequence_length = key_states.shape[1] key_states = key_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + batch_size, + sequence_length, + num_key_value_heads, + num_key_value_groups, + head_dim, ) key_states = key_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + batch_size, + sequence_length, + num_key_value_heads * num_key_value_groups, + head_dim, ) value_states = value_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + batch_size, + sequence_length, + num_key_value_heads, + num_key_value_groups, + head_dim, ) value_states = value_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + batch_size, + sequence_length, + num_key_value_heads * num_key_value_groups, + head_dim, ) # Attention here is upcasted to float32 to match the original eager implementation. @@ -400,7 +442,9 @@ class PaliGemmaWithExpertModel(PreTrainedModel): att_weights *= head_dim**-0.5 big_neg = -2.3819763e38 # See gemma/modules.py - masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) + masked_att_weights = torch.where( + attention_mask[:, None, :, :], att_weights, big_neg + ) probs = nn.functional.softmax(masked_att_weights, dim=-1) probs = probs.to(dtype=value_states.dtype) @@ -412,6 +456,8 @@ class PaliGemmaWithExpertModel(PreTrainedModel): att_output = att_output.permute(0, 2, 1, 3) # we use -1 because sequence length can change - att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) + att_output = att_output.reshape( + batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim + ) return att_output diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index da4ef157..75d077b6 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -71,7 +71,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): def _save_pretrained(self, save_directory: Path) -> None: self.config._save_pretrained(save_directory) model_to_save = self.module if hasattr(self, "module") else self - save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) + save_model_as_safetensor( + model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE) + ) @classmethod def from_pretrained( @@ -110,7 +112,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) - policy = cls._load_as_safetensor(instance, model_file, config.device, strict) + policy = cls._load_as_safetensor( + instance, model_file, config.device, strict + ) else: try: model_file = hf_hub_download( @@ -124,7 +128,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): token=token, local_files_only=local_files_only, ) - policy = cls._load_as_safetensor(instance, model_file, config.device, strict) + policy = cls._load_as_safetensor( + instance, model_file, config.device, strict + ) except HfHubHTTPError as e: raise FileNotFoundError( f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" @@ -135,8 +141,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): return policy @classmethod - def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: - if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): + def _load_as_safetensor( + cls, model: T, model_file: str, map_location: str, strict: bool + ) -> T: + if packaging.version.parse(safetensors.__version__) < packaging.version.parse( + "0.4.3" + ): load_model_as_safetensor(model, model_file, strict=strict) if map_location != "cpu": logging.warning( @@ -147,7 +157,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): ) model.to(map_location) else: - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) + safetensors.torch.load_model( + model, model_file, strict=strict, device=map_location + ) return model # def generate_model_card(self, *args, **kwargs) -> ModelCard: diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index c1b23b3d..ffa91468 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -639,9 +639,9 @@ class Policy(nn.Module): # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) - assert not torch.isnan( - log_std - ).any(), "[ERROR] log_std became NaN after std_layer!" + assert not torch.isnan(log_std).any(), ( + "[ERROR] log_std became NaN after std_layer!" + ) if self.use_tanh_squash: log_std = torch.tanh(log_std) diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index da1edfee..241e8d80 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -187,7 +187,9 @@ class TDMPCConfig(PreTrainedConfig): "If `n_action_steps > 1`, `use_mpc` must be set to `True`." ) if self.n_action_steps > self.horizon: - raise ValueError("`n_action_steps` must be less than or equal to `horizon`.") + raise ValueError( + "`n_action_steps` must be less than or equal to `horizon`." + ) def get_optimizer_preset(self) -> AdamConfig: return AdamConfig(lr=self.optimizer_lr) @@ -207,7 +209,9 @@ class TDMPCConfig(PreTrainedConfig): if image_ft.shape[-2] != image_ft.shape[-1]: # TODO(alexander-soare): This limitation is solely because of code in the random shift # augmentation. It should be able to be removed. - raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.") + raise ValueError( + f"Only square images are handled now. Got image shape {image_ft.shape}." + ) @property def observation_delta_indices(self) -> list: diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 10e8bbcc..9ea6540e 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -39,7 +39,11 @@ from lerobot.common.constants import OBS_ENV, OBS_ROBOT from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig -from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues +from lerobot.common.policies.utils import ( + get_device_from_parameters, + get_output_shape, + populate_queues, +) class TDMPCPolicy(PreTrainedPolicy): @@ -63,7 +67,11 @@ class TDMPCPolicy(PreTrainedPolicy): config_class = TDMPCConfig name = "tdmpc" - def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + def __init__( + self, + config: TDMPCConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of @@ -75,7 +83,9 @@ class TDMPCPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_inputs = Normalize( + config.input_features, config.normalization_mapping, dataset_stats + ) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -117,7 +127,9 @@ class TDMPCPolicy(PreTrainedPolicy): """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[next(iter(self.config.image_features))] self._queues = populate_queues(self._queues, batch) @@ -201,7 +213,10 @@ class TDMPCPolicy(PreTrainedPolicy): # algorithm. # The initial mean and standard deviation for the cross-entropy method (CEM). mean = torch.zeros( - self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device + self.config.horizon, + batch_size, + self.config.action_feature.shape[0], + device=device, ) # Maybe warm start CEM with the mean from the previous step. if self._prev_mean is not None: @@ -339,7 +354,9 @@ class TDMPCPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[next(iter(self.config.image_features))] batch = self.normalize_targets(batch) @@ -371,7 +388,9 @@ class TDMPCPolicy(PreTrainedPolicy): current_observation[k] = observations[k][0] next_observations[k] = observations[k][1:] horizon, batch_size = next_observations[ - "observation.image" if self.config.image_features else "observation.environment_state" + "observation.image" + if self.config.image_features + else "observation.environment_state" ].shape[:2] # Run latent rollout using the latent dynamics model and policy model. @@ -569,7 +588,9 @@ class TDMPCTOLD(nn.Module): self.config = config self._encoder = TDMPCObservationEncoder(config) self._dynamics = nn.Sequential( - nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.Linear( + config.latent_dim + config.action_feature.shape[0], config.mlp_dim + ), nn.LayerNorm(config.mlp_dim), nn.Mish(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -580,7 +601,9 @@ class TDMPCTOLD(nn.Module): nn.Sigmoid(), ) self._reward = nn.Sequential( - nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.Linear( + config.latent_dim + config.action_feature.shape[0], config.mlp_dim + ), nn.LayerNorm(config.mlp_dim), nn.Mish(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -600,7 +623,10 @@ class TDMPCTOLD(nn.Module): self._Qs = nn.ModuleList( [ nn.Sequential( - nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.Linear( + config.latent_dim + config.action_feature.shape[0], + config.mlp_dim, + ), nn.LayerNorm(config.mlp_dim), nn.Tanh(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -786,7 +812,9 @@ class TDMPCObservationEncoder(nn.Module): if config.robot_state_feature: self.state_enc_layers = nn.Sequential( - nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim), + nn.Linear( + config.robot_state_feature.shape[0], config.state_encoder_hidden_dim + ), nn.ELU(), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), nn.LayerNorm(config.latent_dim), @@ -795,7 +823,9 @@ class TDMPCObservationEncoder(nn.Module): if config.env_state_feature: self.env_state_enc_layers = nn.Sequential( - nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim), + nn.Linear( + config.env_state_feature.shape[0], config.state_encoder_hidden_dim + ), nn.ELU(), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), nn.LayerNorm(config.latent_dim), @@ -813,7 +843,8 @@ class TDMPCObservationEncoder(nn.Module): if self.config.image_features: feat.append( flatten_forward_unflatten( - self.image_enc_layers, obs_dict[next(iter(self.config.image_features))] + self.image_enc_layers, + obs_dict[next(iter(self.config.image_features))], ) ) if self.config.env_state_feature: diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index 28e9c433..bf0cb2d0 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -172,7 +172,10 @@ class VQBeTConfig(PreTrainedConfig): if self.crop_shape is not None: for key, image_ft in self.image_features.items(): - if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: + if ( + self.crop_shape[0] > image_ft.shape[1] + or self.crop_shape[1] > image_ft.shape[2] + ): raise ValueError( f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " f"for `crop_shape` and {image_ft.shape} for " @@ -193,7 +196,12 @@ class VQBeTConfig(PreTrainedConfig): @property def action_delta_indices(self) -> list: - return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1)) + return list( + range( + 1 - self.n_obs_steps, + self.n_action_pred_token + self.action_chunk_size - 1, + ) + ) @property def reward_delta_indices(self) -> None: diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 201870dd..a26b20d5 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -29,7 +29,11 @@ from torch import Tensor, nn from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues +from lerobot.common.policies.utils import ( + get_device_from_parameters, + get_output_shape, + populate_queues, +) from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ @@ -60,7 +64,9 @@ class VQBeTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_inputs = Normalize( + config.input_features, config.normalization_mapping, dataset_stats + ) self.normalize_targets = Normalize( config.output_features, config.normalization_mapping, dataset_stats ) @@ -91,11 +97,17 @@ class VQBeTPolicy(PreTrainedPolicy): if self.config.sequentially_select: decay_params = ( decay_params - + list(self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()) - + list(self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()) + + list( + self.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters() + ) + + list( + self.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters() + ) ) else: - decay_params = decay_params + list(self.vqbet.action_head.map_to_cbet_preds_bin.parameters()) + decay_params = decay_params + list( + self.vqbet.action_head.map_to_cbet_preds_bin.parameters() + ) return [ { @@ -133,8 +145,12 @@ class VQBeTPolicy(PreTrainedPolicy): """ batch = self.normalize_inputs(batch) - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[key] for key in self.config.image_features], dim=-4 + ) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -165,8 +181,12 @@ class VQBeTPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[key] for key in self.config.image_features], dim=-4 + ) batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): @@ -334,7 +354,8 @@ class VQBeTModel(nn.Module): # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT. self.state_projector = MLP( - config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim] + config.robot_state_feature.shape[0], + hidden_channels=[self.config.gpt_input_dim], ) self.rgb_feature_projector = MLP( self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim] @@ -406,9 +427,9 @@ class VQBeTModel(nn.Module): features = self.policy(input_tokens) # len(self.config.input_features) is the number of different observation modes. # this line gets the index of action prompt tokens. - historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len( - self.config.input_features - ) + historical_act_pred_index = np.arange(0, n_obs_steps) * ( + len(self.config.input_features) + 1 + ) + len(self.config.input_features) # only extract the output tokens at the position of action query: # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, @@ -771,11 +792,15 @@ class VQBeTRgbEncoder(nn.Module): # height and width from `config.image_features`. images_shape = next(iter(config.image_features.values())).shape - dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] + dummy_shape_h_w = ( + config.crop_shape if config.crop_shape is not None else images_shape[1:] + ) dummy_shape = (1, images_shape[0], *dummy_shape_h_w) feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] - self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.pool = SpatialSoftmax( + feature_map_shape, num_kp=config.spatial_softmax_num_keypoints + ) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() @@ -871,7 +896,8 @@ class VqVae(nn.Module): ) self.encoder = MLP( - in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size, + in_channels=self.config.action_feature.shape[0] + * self.config.action_chunk_size, hidden_channels=[ config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, @@ -899,9 +925,13 @@ class VqVae(nn.Module): # given latent vector, this function outputs the decoded action. output = self.decoder(latent) if self.config.action_chunk_size == 1: - return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) + return einops.rearrange( + output, "N (T A) -> N T A", A=self.config.action_feature.shape[0] + ) else: - return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) + return einops.rearrange( + output, "N (T A) -> N T A", A=self.config.action_feature.shape[0] + ) def get_code(self, state): # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181) diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py index 71e85ac0..3fed3a15 100644 --- a/lerobot/common/policies/vqbet/vqbet_utils.py +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -290,10 +290,10 @@ class GPT(nn.Module): param_dict = dict(self.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), "parameters {} made it into both decay/no_decay sets!".format( - str(inter_params) + assert len(inter_params) == 0, ( + "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) + ) ) assert len(param_dict.keys() - union_params) == 0, ( "parameters {} were not separated into either decay/no_decay set!".format( @@ -664,14 +664,14 @@ class VectorQuantize(nn.Module): self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - assert not ( - ema_update and learnable_codebook - ), "learnable codebook not compatible with EMA update" + assert not (ema_update and learnable_codebook), ( + "learnable codebook not compatible with EMA update" + ) assert 0 <= sync_update_v <= 1.0 - assert not ( - sync_update_v > 0.0 and not learnable_codebook - ), "learnable codebook must be turned on" + assert not (sync_update_v > 0.0 and not learnable_codebook), ( + "learnable codebook must be turned on" + ) self.sync_update_v = sync_update_v diff --git a/lerobot/common/robot_devices/cameras/configs.py b/lerobot/common/robot_devices/cameras/configs.py index 013419a9..ccbc0268 100644 --- a/lerobot/common/robot_devices/cameras/configs.py +++ b/lerobot/common/robot_devices/cameras/configs.py @@ -57,7 +57,9 @@ class OpenCVCameraConfig(CameraConfig): self.channels = 3 if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") + raise ValueError( + f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})" + ) @CameraConfig.register_subclass("intelrealsense") @@ -102,8 +104,12 @@ class IntelRealSenseCameraConfig(CameraConfig): self.channels = 3 - at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None - at_least_one_is_none = self.fps is None or self.width is None or self.height is None + at_least_one_is_not_none = ( + self.fps is not None or self.width is not None or self.height is not None + ) + at_least_one_is_none = ( + self.fps is None or self.width is None or self.height is None + ) if at_least_one_is_not_none and at_least_one_is_none: raise ValueError( "For `fps`, `width` and `height`, either all of them need to be set, or none of them, " @@ -111,4 +117,6 @@ class IntelRealSenseCameraConfig(CameraConfig): ) if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") + raise ValueError( + f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})" + ) diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py index 1282007c..34fc61a9 100644 --- a/lerobot/common/robot_devices/cameras/intelrealsense.py +++ b/lerobot/common/robot_devices/cameras/intelrealsense.py @@ -303,7 +303,11 @@ class IntelRealSenseCamera: if self.fps and self.capture_width and self.capture_height: # TODO(rcadene): can we set rgb8 directly? config.enable_stream( - rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps + rs.stream.color, + self.capture_width, + self.capture_height, + rs.format.rgb8, + self.fps, ) else: config.enable_stream(rs.stream.color) @@ -311,7 +315,11 @@ class IntelRealSenseCamera: if self.use_depth: if self.fps and self.capture_width and self.capture_height: config.enable_stream( - rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps + rs.stream.depth, + self.capture_width, + self.capture_height, + rs.format.z16, + self.fps, ) else: config.enable_stream(rs.stream.depth) diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index 48111e97..87f4e18e 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -144,7 +144,9 @@ def save_images_from_cameras( print("Connecting cameras") cameras = [] for cam_idx in camera_ids: - config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock) + config = OpenCVCameraConfig( + camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock + ) camera = OpenCVCamera(config) camera.connect() print( @@ -250,7 +252,9 @@ class OpenCVCamera: # Retrieve the camera index from a potentially symlinked path self.camera_index = get_camera_index_from_unix_port(self.port) else: - raise ValueError(f"Please check the provided camera_index: {self.camera_index}") + raise ValueError( + f"Please check the provided camera_index: {self.camera_index}" + ) # Store the raw (capture) resolution from the config. self.capture_width = config.width @@ -314,7 +318,11 @@ class OpenCVCamera: else cv2.CAP_ANY ) - camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index + camera_idx = ( + f"/dev/video{self.camera_index}" + if platform.system() == "Linux" + else self.camera_index + ) # First create a temporary camera trying to access `camera_index`, # and verify it is a valid camera by calling `isOpened`. tmp_camera = cv2.VideoCapture(camera_idx, backend) diff --git a/lerobot/common/robot_devices/cameras/utils.py b/lerobot/common/robot_devices/cameras/utils.py index c6431646..db5c9e9d 100644 --- a/lerobot/common/robot_devices/cameras/utils.py +++ b/lerobot/common/robot_devices/cameras/utils.py @@ -41,7 +41,9 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[C cameras[key] = OpenCVCamera(cfg) elif cfg.type == "intelrealsense": - from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera + from lerobot.common.robot_devices.cameras.intelrealsense import ( + IntelRealSenseCamera, + ) cameras[key] = IntelRealSenseCamera(cfg) else: @@ -58,7 +60,9 @@ def make_camera(camera_type, **kwargs) -> Camera: return OpenCVCamera(config) elif camera_type == "intelrealsense": - from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera + from lerobot.common.robot_devices.cameras.intelrealsense import ( + IntelRealSenseCamera, + ) config = IntelRealSenseCameraConfig(**kwargs) return IntelRealSenseCamera(config) diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index 0ecd8683..e3fe7e6a 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -93,7 +93,9 @@ class RecordControlConfig(ControlConfig): policy_path = parser.get_path_arg("control.policy") if policy_path: cli_overrides = parser.get_cli_overrides("control.policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy = PreTrainedConfig.from_pretrained( + policy_path, cli_overrides=cli_overrides + ) self.policy.pretrained_path = policy_path diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index d12f07a4..12e6429d 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -282,7 +282,10 @@ def control_loop( if policy is not None: pred_action = predict_action( - observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + observation, + policy, + get_safe_torch_device(policy.config.device), + policy.config.use_amp, ) # Action can eventually be clipped using `max_relative_target`, # so action actually sent is saved in the dataset. diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py index 4721196d..d477b846 100644 --- a/lerobot/common/robot_devices/motors/dynamixel.py +++ b/lerobot/common/robot_devices/motors/dynamixel.py @@ -23,7 +23,10 @@ import numpy as np import tqdm from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, +) from lerobot.common.utils.utils import capture_timestamp_utc PROTOCOL_VERSION = 2.0 diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py index 0941428c..cfa249b8 100644 --- a/lerobot/common/robot_devices/motors/feetech.py +++ b/lerobot/common/robot_devices/motors/feetech.py @@ -23,7 +23,10 @@ import numpy as np import tqdm from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, +) from lerobot.common.utils.utils import capture_timestamp_utc PROTOCOL_VERSION = 0 diff --git a/lerobot/common/robot_devices/motors/utils.py b/lerobot/common/robot_devices/motors/utils.py index bd86f4c6..acb6d40f 100644 --- a/lerobot/common/robot_devices/motors/utils.py +++ b/lerobot/common/robot_devices/motors/utils.py @@ -30,7 +30,9 @@ class MotorsBus(Protocol): def write(self): ... -def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]: +def make_motors_buses_from_configs( + motors_bus_configs: dict[str, MotorsBusConfig], +) -> list[MotorsBus]: motors_buses = {} for key, cfg in motors_bus_configs.items(): diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py index e940b442..d0a7f40a 100644 --- a/lerobot/common/robot_devices/robots/configs.py +++ b/lerobot/common/robot_devices/robots/configs.py @@ -69,9 +69,13 @@ class ManipulatorRobotConfig(RobotConfig): if not cam.mock: cam.mock = True - if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence): + if self.max_relative_target is not None and isinstance( + self.max_relative_target, Sequence + ): for name in self.follower_arms: - if len(self.follower_arms[name].motors) != len(self.max_relative_target): + if len(self.follower_arms[name].motors) != len( + self.max_relative_target + ): raise ValueError( f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has " f"{len(self.follower_arms[name].motors)} motors. Please make sure that the " diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py index 7bf52d21..a1ad4c9f 100644 --- a/lerobot/common/robot_devices/robots/lekiwi_remote.py +++ b/lerobot/common/robot_devices/robots/lekiwi_remote.py @@ -42,7 +42,9 @@ def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event): local_dict = {} for name, cam in cameras.items(): frame = cam.async_read() - ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) + ret, buffer = cv2.imencode( + ".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90] + ) if ret: local_dict[name] = base64.b64encode(buffer).decode("utf-8") else: @@ -61,7 +63,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str): calib_dir.mkdir(parents=True, exist_ok=True) calib_file = calib_dir / "main_follower.json" try: - from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration + from lerobot.common.robot_devices.robots.feetech_calibration import ( + run_arm_manual_calibration, + ) except ImportError: print("[WARNING] Calibration function not available. Skipping calibration.") return @@ -72,7 +76,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str): print(f"[INFO] Loaded calibration from {calib_file}") else: print("[INFO] Calibration file not found. Running manual calibration...") - calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower") + calibration = run_arm_manual_calibration( + motors_bus, "lekiwi", "follower_arm", "follower" + ) print(f"[INFO] Calibration complete. Saving to {calib_file}") with open(calib_file, "w") as f: json.dump(calibration, f) @@ -116,7 +122,14 @@ def run_lekiwi(robot_config): robot = LeKiwi(motors_bus) # Define the expected arm motor IDs. - arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"] + arm_motor_ids = [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ] # Disable torque for each arm motor. for motor in arm_motor_ids: @@ -130,7 +143,9 @@ def run_lekiwi(robot_config): images_lock = threading.Lock() stop_event = threading.Event() cam_thread = threading.Thread( - target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True + target=run_camera_capture, + args=(cameras, images_lock, latest_images_dict, stop_event), + daemon=True, ) cam_thread.start() @@ -159,7 +174,9 @@ def run_lekiwi(robot_config): f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}" ) else: - for motor, pos in zip(arm_motor_ids, arm_positions, strict=False): + for motor, pos in zip( + arm_motor_ids, arm_positions, strict=False + ): motors_bus.write("Goal_Position", pos, motor) # Process wheel (base) commands. if "raw_velocity" in data: @@ -190,7 +207,9 @@ def run_lekiwi(robot_config): try: pos = motors_bus.read("Present_Position", motor) # Convert the position to a float (or use as is if already numeric). - follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos) + follower_arm_state.append( + float(pos) if not isinstance(pos, (int, float)) else pos + ) except Exception as e: print(f"[ERROR] Reading motor {motor} failed: {e}") diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py index 334100ca..fbeaad5e 100644 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ b/lerobot/common/robot_devices/robots/manipulator.py @@ -28,7 +28,10 @@ import numpy as np import torch from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs -from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs +from lerobot.common.robot_devices.motors.utils import ( + MotorsBus, + make_motors_buses_from_configs, +) from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.utils import ( diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py index 385e218b..bf12c1d5 100644 --- a/lerobot/common/robot_devices/robots/mobile_manipulator.py +++ b/lerobot/common/robot_devices/robots/mobile_manipulator.py @@ -25,9 +25,14 @@ import zmq from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs from lerobot.common.robot_devices.motors.feetech import TorqueMode -from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs +from lerobot.common.robot_devices.motors.utils import ( + MotorsBus, + make_motors_buses_from_configs, +) from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig -from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration +from lerobot.common.robot_devices.robots.feetech_calibration import ( + run_arm_manual_calibration, +) from lerobot.common.robot_devices.robots.utils import get_arm_id from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError @@ -266,7 +271,9 @@ class MobileManipulator: calibration = json.load(f) else: print(f"Missing calibration file '{arm_calib_path}'") - calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) + calibration = run_arm_manual_calibration( + arm, self.robot_type, name, arm_type + ) print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") arm_calib_path.parent.mkdir(parents=True, exist_ok=True) with open(arm_calib_path, "w") as f: @@ -296,7 +303,9 @@ class MobileManipulator: bus.write("Torque_Enable", 0, motor_id) # Then filter out wheels - arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")} + arm_only_dict = { + k: v for k, v in bus.motors.items() if not k.startswith("wheel_") + } if not arm_only_dict: continue @@ -324,7 +333,11 @@ class MobileManipulator: socks = dict(poller.poll(15)) if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN: # No new data arrived → reuse ALL old data - return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) + return ( + self.last_frames, + self.last_present_speed, + self.last_remote_arm_state, + ) # Drain all messages, keep only the last last_msg = None @@ -337,7 +350,11 @@ class MobileManipulator: if not last_msg: # No new message → also reuse old - return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) + return ( + self.last_frames, + self.last_present_speed, + self.last_remote_arm_state, + ) # Decode only the final message try: @@ -360,7 +377,9 @@ class MobileManipulator: if new_arm_state is not None and frames is not None: self.last_frames = frames - remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32) + remote_arm_state_tensor = torch.tensor( + new_arm_state, dtype=torch.float32 + ) self.last_remote_arm_state = remote_arm_state_tensor present_speed = new_speed @@ -375,14 +394,21 @@ class MobileManipulator: except Exception as e: print(f"[DEBUG] Error decoding video message: {e}") # If decode fails, fall back to old data - return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) + return ( + self.last_frames, + self.last_present_speed, + self.last_remote_arm_state, + ) return frames, present_speed, remote_arm_state_tensor def _process_present_speed(self, present_speed: dict) -> torch.Tensor: state_tensor = torch.zeros(3, dtype=torch.int32) if present_speed: - decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()} + decoded = { + key: MobileManipulator.raw_to_degps(value) + for key, value in present_speed.items() + } if "1" in decoded: state_tensor[0] = decoded["1"] if "2" in decoded: @@ -395,7 +421,9 @@ class MobileManipulator: self, record_data: bool = False ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: if not self.is_connected: - raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.") + raise RobotDeviceNotConnectedError( + "MobileManipulator is not connected. Run `connect()` first." + ) speed_setting = self.speed_levels[self.speed_index] xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4 @@ -461,9 +489,15 @@ class MobileManipulator: body_state = self.wheel_raw_to_body(present_speed) - body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s + body_state_mm = ( + body_state[0] * 1000.0, + body_state[1] * 1000.0, + body_state[2], + ) # Convert x,y to mm/s wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32) - combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0) + combined_state_tensor = torch.cat( + (remote_arm_state_tensor, wheel_state_tensor), dim=0 + ) obs_dict = {"observation.state": combined_state_tensor} @@ -620,7 +654,11 @@ class MobileManipulator: # Convert each wheel’s angular speed (deg/s) to a raw integer. wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps] - return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]} + return { + "left_wheel": wheel_raw[0], + "back_wheel": wheel_raw[1], + "right_wheel": wheel_raw[2], + } def wheel_raw_to_body( self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125 diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py index dab514d5..d01fe580 100644 --- a/lerobot/common/robot_devices/robots/utils.py +++ b/lerobot/common/robot_devices/robots/utils.py @@ -72,7 +72,9 @@ def make_robot_from_config(config: RobotConfig): return ManipulatorRobot(config) elif isinstance(config, LeKiwiRobotConfig): - from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator + from lerobot.common.robot_devices.robots.mobile_manipulator import ( + MobileManipulator, + ) return MobileManipulator(config) else: diff --git a/lerobot/common/utils/hub.py b/lerobot/common/utils/hub.py index df7435c0..972b5b2b 100644 --- a/lerobot/common/utils/hub.py +++ b/lerobot/common/utils/hub.py @@ -69,7 +69,9 @@ class HubMixin: if push_to_hub: if repo_id is None: repo_id = save_directory.name # Defaults to `save_directory` name - return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs) + return self.push_to_hub( + repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs + ) return None def _save_pretrained(self, save_directory: Path) -> None: @@ -175,7 +177,9 @@ class HubMixin: The url of the commit of your object in the given repository. """ api = HfApi(token=token) - repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id + repo_id = api.create_repo( + repo_id=repo_id, private=private, exist_ok=True + ).repo_id if commit_message is None: if "Policy" in self.__class__.__name__: diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index c67d8e1e..339d864c 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -20,7 +20,16 @@ from typing import TypeVar import imageio -JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...] +JsonLike = ( + str + | int + | float + | bool + | None + | list["JsonLike"] + | dict[str, "JsonLike"] + | tuple["JsonLike", ...] +) T = TypeVar("T", bound=JsonLike) @@ -76,7 +85,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T: # Check length if len(target) != len(source): - raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}") + raise ValueError( + f"List length mismatch: expected {len(target)}, got {len(source)}" + ) # Recursively update each element. for i in range(len(target)): @@ -88,10 +99,14 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T: # which we'll convert back to a tuple. elif isinstance(target, tuple): if not isinstance(source, list): - raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}") + raise TypeError( + f"Type mismatch: expected list (for tuple), got {type(source)}" + ) if len(target) != len(source): - raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}") + raise ValueError( + f"Tuple length mismatch: expected {len(target)}, got {len(source)}" + ) # Convert each element, forming a new tuple. converted_items = [] @@ -105,7 +120,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T: else: # Check the exact type. If these must match 1:1, do: if type(target) is not type(source): - raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}") + raise TypeError( + f"Type mismatch: expected {type(target)}, got {type(source)}" + ) return source # Perform the in-place/recursive deserialization diff --git a/lerobot/common/utils/logging_utils.py b/lerobot/common/utils/logging_utils.py index b99c348f..f38747d8 100644 --- a/lerobot/common/utils/logging_utils.py +++ b/lerobot/common/utils/logging_utils.py @@ -107,13 +107,17 @@ class MetricsTracker: self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames - def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: + def __getattr__( + self, name: str + ) -> int | dict[str, AverageMeter] | AverageMeter | Any: if name in self.__dict__: return self.__dict__[name] elif name in self.metrics: return self.metrics[name] else: - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) def __setattr__(self, name: str, value: Any) -> None: if name in self.__dict__: @@ -121,7 +125,9 @@ class MetricsTracker: elif name in self.metrics: self.metrics[name].update(value) else: - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) def step(self) -> None: """ diff --git a/lerobot/common/utils/random_utils.py b/lerobot/common/utils/random_utils.py index 3d9bf4dd..c9327125 100644 --- a/lerobot/common/utils/random_utils.py +++ b/lerobot/common/utils/random_utils.py @@ -42,7 +42,11 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non """ Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`. """ - py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None) + py_state = ( + rng_state_dict["py_rng_version"].item(), + tuple(rng_state_dict["py_rng_state"].tolist()), + None, + ) random.setstate(py_state) @@ -119,7 +123,9 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: """ py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")} np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")} - torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")} + torch_rng_state_dict = { + k: v for k, v in rng_state_dict.items() if k.startswith("torch") + } deserialize_python_rng_state(py_rng_state_dict) deserialize_numpy_rng_state(np_rng_state_dict) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 272fa614..e0b6aafc 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -48,7 +48,9 @@ def auto_select_torch_device() -> torch.device: logging.info("Metal backend detected, using cuda.") return torch.device("mps") else: - logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") + logging.warning( + "No accelerated backend detected. Using default cpu, this will be slow." + ) return torch.device("cpu") @@ -96,7 +98,9 @@ def is_torch_device_available(try_device: str) -> bool: elif try_device == "cpu": return True else: - raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.") + raise ValueError( + f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu." + ) def is_amp_available(device: str): @@ -219,7 +223,9 @@ def say(text, blocking=False): if blocking: subprocess.run(cmd, check=True) else: - subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0) + subprocess.Popen( + cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0 + ) def log_say(text, play_sounds, blocking=False): diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index 700ebea5..30915c3e 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -26,7 +26,9 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR from lerobot.configs.train import TrainPipelineConfig -def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: +def cfg_to_group( + cfg: TrainPipelineConfig, return_list: bool = False +) -> list[str] | str: """Return a group name for logging. Optionally returns group name as list.""" lst = [ f"policy:{cfg.policy.type}", @@ -92,7 +94,9 @@ class WandBLogger: resume="must" if cfg.resume else None, ) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) - logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") + logging.info( + f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}" + ) self._wandb = wandb def log_policy(self, checkpoint_dir: Path): @@ -104,7 +108,9 @@ class WandBLogger: artifact_name = f"{self._group}-{step_id}" artifact_name = get_safe_wandb_artifact_name(artifact_name) artifact = self._wandb.Artifact(artifact_name, type="model") - artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) + artifact.add_file( + checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE + ) self._wandb.log_artifact(artifact) def log_dict(self, d: dict, step: int, mode: str = "train"): diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py index b23bbb6d..96482d5e 100644 --- a/lerobot/configs/default.py +++ b/lerobot/configs/default.py @@ -33,7 +33,9 @@ class DatasetConfig: # Root directory where the dataset will be stored (e.g. 'dataset/path'). root: str | None = None episodes: list[int] | None = None - image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) + image_transforms: ImageTransformsConfig = field( + default_factory=ImageTransformsConfig + ) revision: str | None = None use_imagenet_stats: bool = True video_backend: str = field(default_factory=get_safe_default_codec) diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py index 16b35291..f1a26d47 100644 --- a/lerobot/configs/eval.py +++ b/lerobot/configs/eval.py @@ -40,7 +40,9 @@ class EvalPipelineConfig: policy_path = parser.get_path_arg("policy") if policy_path: cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy = PreTrainedConfig.from_pretrained( + policy_path, cli_overrides=cli_overrides + ) self.policy.pretrained_path = policy_path else: diff --git a/lerobot/configs/parser.py b/lerobot/configs/parser.py index 39e31515..3cdf0578 100644 --- a/lerobot/configs/parser.py +++ b/lerobot/configs/parser.py @@ -29,7 +29,9 @@ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" draccus.set_config_type("json") -def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None: +def get_cli_overrides( + field_name: str, args: Sequence[str] | None = None +) -> list[str] | None: """Parses arguments from cli at a given nested attribute level. For example, supposing the main script was called with: @@ -42,7 +44,10 @@ def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> lis args = sys.argv[1:] attr_level_args = [] detect_string = f"--{field_name}." - exclude_strings = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=") + exclude_strings = ( + f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", + f"--{field_name}.{PATH_KEY}=", + ) for arg in args: if arg.startswith(detect_string) and not arg.startswith(exclude_strings): denested_arg = f"--{arg.removeprefix(detect_string)}" @@ -153,7 +158,9 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[ return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")] -def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]: +def filter_path_args( + fields_to_filter: str | list[str], args: Sequence[str] | None = None +) -> list[str]: """ Filters command-line arguments related to fields with specific path arguments. @@ -181,7 +188,9 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No argument=None, message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}", ) - filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")] + filtered_args = [ + arg for arg in filtered_args if not arg.startswith(f"--{field}.") + ] return filtered_args @@ -213,7 +222,9 @@ def wrap(config_path: Path | None = None): load_plugin(plugin_path) except PluginLoadError as e: # add the relevant CLI arg to the error message - raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e + raise PluginLoadError( + f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}" + ) from e cli_args = filter_arg(plugin_cli_arg, cli_args) config_path_cli = parse_arg("config_path", cli_args) if has_method(argtype, "__get_path_fields__"): @@ -223,7 +234,9 @@ def wrap(config_path: Path | None = None): cli_args = filter_arg("config_path", cli_args) cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) else: - cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args) + cfg = draccus.parse( + config_class=argtype, config_path=config_path, args=cli_args + ) response = fn(cfg, *args, **kwargs) return response diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py index 022d1fb5..32c3125f 100644 --- a/lerobot/configs/policies.py +++ b/lerobot/configs/policies.py @@ -26,7 +26,11 @@ from huggingface_hub.errors import HfHubHTTPError from lerobot.common.optim.optimizers import OptimizerConfig from lerobot.common.optim.schedulers import LRSchedulerConfig from lerobot.common.utils.hub import HubMixin -from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available +from lerobot.common.utils.utils import ( + auto_select_torch_device, + is_amp_available, + is_torch_device_available, +) from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature # Generic variable that is either PreTrainedConfig or a subclass thereof @@ -64,7 +68,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): self.pretrained_path = None if not self.device or not is_torch_device_available(self.device): auto_device = auto_select_torch_device() - logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.") + logging.warning( + f"Device '{self.device}' is not available. Switching to '{auto_device}'." + ) self.device = auto_device.type # Automatically deactivate AMP if necessary @@ -118,7 +124,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def image_features(self) -> dict[str, PolicyFeature]: - return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL} + return { + key: ft + for key, ft in self.input_features.items() + if ft.type is FeatureType.VISUAL + } @property def action_feature(self) -> PolicyFeature | None: diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 7a787b83..b4adeb89 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -73,7 +73,9 @@ class TrainPipelineConfig(HubMixin): if policy_path: # Only load the policy config cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy = PreTrainedConfig.from_pretrained( + policy_path, cli_overrides=cli_overrides + ) self.policy.pretrained_path = policy_path elif self.resume: # The entire train config is already loaded, we just need to get the checkpoint dir @@ -97,7 +99,11 @@ class TrainPipelineConfig(HubMixin): else: self.job_name = f"{self.env.type}_{self.policy.type}" - if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir(): + if ( + not self.resume + and isinstance(self.output_dir, Path) + and self.output_dir.is_dir() + ): raise FileExistsError( f"Output directory {self.output_dir} already exists and resume is {self.resume}. " f"Please change your output directory so that {self.output_dir} is not overwritten." @@ -108,10 +114,16 @@ class TrainPipelineConfig(HubMixin): self.output_dir = Path("outputs/train") / train_dir if isinstance(self.dataset.repo_id, list): - raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") + raise NotImplementedError( + "LeRobotMultiDataset is not currently implemented." + ) - if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): - raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") + if not self.use_policy_training_preset and ( + self.optimizer is None or self.scheduler is None + ): + raise ValueError( + "Optimizer and Scheduler must be set when the policy presets are not used." + ) elif self.use_policy_training_preset and not self.resume: self.optimizer = self.policy.get_optimizer_preset() self.scheduler = self.policy.get_scheduler_preset() @@ -125,7 +137,10 @@ class TrainPipelineConfig(HubMixin): return draccus.encode(self) def _save_pretrained(self, save_directory: Path) -> None: - with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"): + with ( + open(save_directory / TRAIN_CONFIG_NAME, "w") as f, + draccus.config_type("json"), + ): draccus.dump(self, f, indent=4) @classmethod diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py index 3b395129..985181a8 100644 --- a/lerobot/scripts/configure_motor.py +++ b/lerobot/scripts/configure_motor.py @@ -38,7 +38,12 @@ def get_motor_bus_cls(brand: str) -> tuple: FeetechMotorsBus, ) - return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE + return ( + FeetechMotorsBusConfig, + FeetechMotorsBus, + MODEL_BAUDRATE_TABLE, + SCS_SERIES_BAUDRATE_TABLE, + ) elif brand == "dynamixel": from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig @@ -48,7 +53,12 @@ def get_motor_bus_cls(brand: str) -> tuple: DynamixelMotorsBus, ) - return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE + return ( + DynamixelMotorsBusConfig, + DynamixelMotorsBus, + MODEL_BAUDRATE_TABLE, + X_SERIES_BAUDRATE_TABLE, + ) else: raise ValueError( @@ -57,8 +67,8 @@ def get_motor_bus_cls(brand: str) -> tuple: def configure_motor(port, brand, model, motor_idx_des, baudrate_des): - motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls( - brand + motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = ( + get_motor_bus_cls(brand) ) # Check if the provided model exists in the model_baud_rate_table @@ -72,7 +82,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument motor_model = model # Use the motor model passed via argument - config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}) + config = motor_bus_config_cls( + port=port, motors={motor_name: (motor_index_arbitrary, motor_model)} + ) # Initialize the MotorBus with the correct port and motor configurations motor_bus = motor_bus_cls(config=config) @@ -139,8 +151,12 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): print(f"Setting its index to desired index {motor_idx_des}") if brand == "feetech": - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des) + motor_bus.write_with_motor_ids( + motor_bus.motor_models, motor_index, "Lock", 0 + ) + motor_bus.write_with_motor_ids( + motor_bus.motor_models, motor_index, "ID", motor_idx_des + ) present_idx = motor_bus.read_with_motor_ids( motor_bus.motor_models, motor_idx_des, "ID", num_retry=2 diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 016ce2e9..f24aa3ac 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -156,7 +156,6 @@ from lerobot.common.robot_devices.control_utils import ( log_control_info, record_episode, reset_environment, - reset_follower_position, sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, stop_recording, @@ -251,7 +250,8 @@ def record( if len(robot.cameras) > 0: dataset.start_image_writer( num_processes=cfg.num_image_writer_processes, - num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + num_threads=cfg.num_image_writer_threads_per_camera + * len(robot.cameras), ) sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) else: @@ -264,14 +264,19 @@ def record( robot=robot, use_videos=cfg.video, image_writer_processes=cfg.num_image_writer_processes, - image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + image_writer_threads=cfg.num_image_writer_threads_per_camera + * len(robot.cameras), ) # Load pretrained policy - policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + policy = ( + None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + ) # Load pretrained policy - policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + policy = ( + None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + ) if not robot.is_connected: robot.connect() @@ -286,7 +291,14 @@ def record( # 3. place the cameras windows on screen enable_teleoperation = policy is None log_say("Warmup record", cfg.play_sounds) - warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps) + warmup_record( + robot, + events, + enable_teleoperation, + cfg.warmup_time_s, + cfg.display_cameras, + cfg.fps, + ) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 1eea4223..2bf72640 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -262,7 +262,11 @@ def record( shape = env.observation_space[key].shape if not key.startswith("observation.image."): key = "observation.image." + key - features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape} + features[key] = { + "dtype": "video", + "names": ["channels", "height", "width"], + "shape": shape, + } for key, obs_key in state_keys_dict.items(): features[key] = { diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 3c679530..f5a15538 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -152,7 +152,8 @@ def rollout( all_observations.append(deepcopy(observation)) observation = { - key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation + key: observation[key].to(device, non_blocking=device.type == "cuda") + for key in observation } with torch.inference_mode(): @@ -511,10 +512,14 @@ def eval_main(cfg: EvalPipelineConfig): torch.backends.cuda.matmul.allow_tf32 = True set_seed(cfg.seed) - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + logging.info( + colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}" + ) logging.info("Making environment.") - env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + env = make_env( + cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs + ) logging.info("Making policy.") @@ -524,7 +529,12 @@ def eval_main(cfg: EvalPipelineConfig): ) policy.eval() - with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) + if cfg.policy.use_amp + else nullcontext(), + ): info = eval_policy( env, policy, diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 253a8ebd..b7522d29 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -480,7 +480,7 @@ def test_forward_kinematics(robot, fps=10): obs = robot.capture_observation() joint_positions = obs["observation.state"].cpu().numpy() ee_pos = RobotKinematics.fk_gripper_tip(joint_positions) - logging.info(f"EE Position: {ee_pos[:3,3]}") + logging.info(f"EE Position: {ee_pos[:3, 3]}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) @@ -519,7 +519,7 @@ def teleoperate_inverse_kinematics_with_leader(robot, fps=10): joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func ) robot.send_action(torch.from_numpy(target_joint_state)) - logging.info(f"Leader EE: {leader_ee[:3,3]}, Follower EE: {ee_pos[:3,3]}") + logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) @@ -574,9 +574,9 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10): # Logging logging.info( - f"Current EE: {current_ee_pos[:3,3]}, Desired EE: {desired_ee_pos[:3,3]}" + f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}" ) - logging.info(f"Delta EE: {ee_delta[:3,3]}") + logging.info(f"Delta EE: {ee_delta[:3, 3]}") busy_wait(1 / fps - (time.perf_counter() - loop_start_time)) diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index 728afdfa..2dd7bdbb 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1087,9 +1087,9 @@ class GamepadControlWrapper(gym.Wrapper): class ActionScaleWrapper(gym.ActionWrapper): def __init__(self, env, ee_action_space_params=None): super().__init__(env) - assert ( - ee_action_space_params is not None - ), "TODO: method implemented for ee action space only so far" + assert ee_action_space_params is not None, ( + "TODO: method implemented for ee action space only so far" + ) self.scale_vector = np.array( [ [ @@ -1546,4 +1546,4 @@ if __name__ == "__main__": busy_wait(1 / args.fps - dt_s) logging.info(f"Success after 20 steps {sucesses}") - logging.info(f"success rate {sum(sucesses)/ len(sucesses)}") + logging.info(f"success rate {sum(sucesses) / len(sucesses)}") diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/scripts/server/network_utils.py index f5e8973b..03ca06ca 100644 --- a/lerobot/scripts/server/network_utils.py +++ b/lerobot/scripts/server/network_utils.py @@ -41,7 +41,7 @@ def send_bytes_in_chunks( logging_method = logging.info if not silent else logging.debug - logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB with") + logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") while sent_bytes < size_in_bytes: transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE @@ -60,7 +60,7 @@ def send_bytes_in_chunks( f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}" ) - logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB") + logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") def receive_bytes_in_chunks( diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 5dee312b..d9cbc04b 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -223,14 +223,18 @@ def train(cfg: TrainPipelineConfig): step = 0 # number of policy updates (forward + backward + optim) if cfg.resume: - step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) + step, optimizer, lr_scheduler = load_training_state( + cfg.checkpoint_path, optimizer, lr_scheduler + ) num_learnable_params = sum( p.numel() for p in policy.parameters() if p.requires_grad ) num_total_params = sum(p.numel() for p in policy.parameters()) - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + logging.info( + colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}" + ) if cfg.env is not None: logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") @@ -273,7 +277,11 @@ def train(cfg: TrainPipelineConfig): } train_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + train_metrics, + initial_step=step, ) logging.info("Start offline training on a fixed dataset") @@ -327,7 +335,9 @@ def train(cfg: TrainPipelineConfig): logging.info(f"Eval policy at step {step}") with ( torch.no_grad(), - torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), + torch.autocast(device_type=device.type) + if cfg.policy.use_amp + else nullcontext(), ): eval_info = eval_policy( eval_env, @@ -344,7 +354,11 @@ def train(cfg: TrainPipelineConfig): "eval_s": AverageMeter("eval_s", ":.3f"), } eval_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step + cfg.batch_size, + dataset.num_frames, + dataset.num_episodes, + eval_metrics, + initial_step=step, ) eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 1cae0183..5d877bba 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -426,7 +426,7 @@ def train( # Training loop with validation and checkpointing for epoch in range(cfg.training.num_epochs): - logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}") + logging.info(f"\nEpoch {epoch + 1}/{cfg.training.num_epochs}") train_epoch( model, @@ -470,7 +470,7 @@ def train( policy=model, optimizer=optimizer, scheduler=None, - identifier=f"{epoch+1:06d}", + identifier=f"{epoch + 1:06d}", ) step += len(train_loader) diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 340e6516..65779d85 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -94,9 +94,9 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.ndim == 3 c, h, w = chw_float32_torch.shape - assert ( - c < h and c < w - ), f"expect channel first images, but instead {chw_float32_torch.shape}" + assert c < h and c < w, ( + f"expect channel first images, but instead {chw_float32_torch.shape}" + ) hwc_uint8_numpy = ( (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() ) diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 51dbf4c2..71f13148 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -158,7 +158,9 @@ def run_server( 400, ) dataset_version = ( - str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version + str(dataset.meta._version) + if isinstance(dataset, LeRobotDataset) + else dataset.codebase_version ) match = re.search(r"v(\d+)\.", dataset_version) if match: @@ -166,7 +168,9 @@ def run_server( if major_version < 2: return "Make sure to convert your LeRobotDataset to v2 & above." - episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id) + episode_data_csv_str, columns, ignored_columns = get_episode_data( + dataset, episode_id + ) dataset_info = { "repo_id": f"{dataset_namespace}/{dataset_name}", "num_samples": dataset.num_frames @@ -208,7 +212,8 @@ def run_server( ] response = requests.get( - f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5 + f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", + timeout=5, ) response.raise_for_status() # Split into lines and parse each line as JSON @@ -256,7 +261,11 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) This file will be loaded by Dygraph javascript to plot data in real time.""" columns = [] - selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]] + selected_columns = [ + col + for col, ft in dataset.features.items() + if ft["dtype"] in ["float32", "int32"] + ] selected_columns.remove("timestamp") ignored_columns = [] @@ -361,7 +370,8 @@ def get_episode_language_instruction( def get_dataset_info(repo_id: str) -> IterableNamespace: response = requests.get( - f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5 + f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", + timeout=5, ) response.raise_for_status() # Raises an HTTPError for bad responses dataset_info = response.json() diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py index 80935d32..4b6bb900 100644 --- a/lerobot/scripts/visualize_image_transforms.py +++ b/lerobot/scripts/visualize_image_transforms.py @@ -47,7 +47,9 @@ OUTPUT_DIR = Path("outputs/image_transforms") to_pil = ToPILImage() -def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples): +def save_all_transforms( + cfg: ImageTransformsConfig, original_frame, output_dir, n_examples +): output_dir_all = output_dir / "all" output_dir_all.mkdir(parents=True, exist_ok=True) @@ -60,7 +62,9 @@ def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, print(f" {output_dir_all}") -def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples): +def save_each_transform( + cfg: ImageTransformsConfig, original_frame, output_dir, n_examples +): if not cfg.enable: logging.warning( "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`." @@ -89,9 +93,15 @@ def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, tf_cfg_kwgs_max[key] = [max_, max_] tf_cfg_kwgs_avg[key] = [avg, avg] - tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})) - tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})) - tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})) + tf_min = make_transform_from_config( + replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}) + ) + tf_max = make_transform_from_config( + replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}) + ) + tf_avg = make_transform_from_config( + replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}) + ) tf_frame_min = tf_min(original_frame) tf_frame_max = tf_max(original_frame) @@ -105,7 +115,9 @@ def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, @draccus.wrap() -def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5): +def visualize_image_transforms( + cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5 +): dataset = LeRobotDataset( repo_id=cfg.repo_id, episodes=cfg.episodes, diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 106f0dc0..c1465074 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -51,7 +51,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): batch = next(iter(dataloader)) loss, output_dict = policy.forward(batch) if output_dict is not None: - output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} + output_dict = { + k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor) + } output_dict["loss"] = loss else: output_dict = {"loss": loss} @@ -69,7 +71,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): param_stats = {} for key, param in policy.named_parameters(): param_stats[f"{key}_mean"] = param.mean() - param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0)) + param_stats[f"{key}_std"] = ( + param.std() if param.numel() > 1 else torch.tensor(float(0.0)) + ) optimizer.zero_grad() policy.reset() @@ -96,11 +100,15 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): else: actions_queue = train_cfg.policy.n_action_repeats - actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} + actions = { + str(i): policy.select_action(obs).contiguous() for i in range(actions_queue) + } return output_dict, grad_stats, param_stats, actions -def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict): +def save_policy_to_safetensors( + output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict +): if output_dir.exists(): print(f"Overwrite existing safetensors in '{output_dir}':") print(f" - Validate with: `git add {output_dir}`") @@ -108,7 +116,9 @@ def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: s shutil.rmtree(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) + output_dict, grad_stats, param_stats, actions = get_policy_stats( + ds_repo_id, policy_name, policy_kwargs + ) save_file(output_dict, output_dir / "output_dict.safetensors") save_file(grad_stats, output_dir / "grad_stats.safetensors") save_file(param_stats, output_dir / "param_stats.safetensors") @@ -141,5 +151,7 @@ if __name__ == "__main__": raise RuntimeError("No policies were provided!") for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg: ds_name = ds_repo_id.split("/")[-1] - output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}" + output_dir = ( + Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}" + ) save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs) diff --git a/tests/cameras/test_cameras.py b/tests/cameras/test_cameras.py index 971ac4e0..7ead111b 100644 --- a/tests/cameras/test_cameras.py +++ b/tests/cameras/test_cameras.py @@ -226,7 +226,13 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock): @pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) @require_camera def test_camera_rotation(request, camera_type, mock): - config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30} + config_kwargs = { + "camera_type": camera_type, + "mock": mock, + "width": 640, + "height": 480, + "fps": 30, + } # No rotation. camera = make_camera(**config_kwargs, rotation=None) diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py index 1a8cceed..0b990fa3 100644 --- a/tests/configs/test_plugin_loading.py +++ b/tests/configs/test_plugin_loading.py @@ -9,7 +9,9 @@ from lerobot.common.envs.configs import EnvConfig from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap -def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str: +def create_plugin_code( + *, base_class: str = "EnvConfig", plugin_name: str = "test_env" +) -> str: """Creates a dummy plugin module that implements its own EnvConfig subclass.""" return f""" from dataclasses import dataclass diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index d9032c8a..dbedd35c 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -31,7 +31,11 @@ from lerobot.common.datasets.compute_stats import ( def mock_load_image_as_numpy(path, dtype, channel_first): - return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) + return ( + np.ones((3, 32, 32), dtype=dtype) + if channel_first + else np.ones((32, 32, 3), dtype=dtype) + ) @pytest.fixture @@ -61,7 +65,10 @@ def test_sample_indices(): assert len(indices) == estimate_num_samples(10) -@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy) +@patch( + "lerobot.common.datasets.compute_stats.load_image_as_numpy", + side_effect=mock_load_image_as_numpy, +) def test_sample_images(mock_load): image_paths = [f"image_{i}.jpg" for i in range(100)] images = sample_images(image_paths) @@ -74,9 +81,20 @@ def test_sample_images(mock_load): def test_get_feature_stats_images(): data = np.random.rand(100, 3, 32, 32) stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) - assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats + assert ( + "min" in stats + and "max" in stats + and "mean" in stats + and "std" in stats + and "count" in stats + ) np.testing.assert_equal(stats["count"], np.array([100])) - assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + assert ( + stats["min"].shape + == stats["max"].shape + == stats["mean"].shape + == stats["std"].shape + ) def test_get_feature_stats_axis_0_keepdims(sample_array): @@ -145,7 +163,8 @@ def test_compute_episode_stats(): } with patch( - "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + "lerobot.common.datasets.compute_stats.load_image_as_numpy", + side_effect=mock_load_image_as_numpy, ): stats = compute_episode_stats(episode_data, features) @@ -233,7 +252,13 @@ def test_aggregate_stats(): "std": [2.87, 5.87, 8.87], "count": 10, }, - "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, + "observation.state": { + "min": 1, + "max": 10, + "mean": 5.5, + "std": 2.87, + "count": 10, + }, "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, }, { @@ -244,7 +269,13 @@ def test_aggregate_stats(): "std": [3.42, 2.42, 1.42], "count": 15, }, - "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, + "observation.state": { + "min": 2, + "max": 15, + "mean": 8.5, + "std": 3.42, + "count": 15, + }, "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, }, ] @@ -284,28 +315,47 @@ def test_aggregate_stats(): for ep_stats in all_stats: for fkey, stats in ep_stats.items(): for k in stats: - stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) + stats[k] = np.array( + stats[k], dtype=np.int64 if k == "count" else np.float32 + ) if fkey == "observation.image" and k != "count": - stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels + stats[k] = stats[k].reshape( + 3, 1, 1 + ) # for normalization on image channels else: stats[k] = stats[k].reshape(1) # cast to numpy for fkey, stats in expected_agg_stats.items(): for k in stats: - stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) + stats[k] = np.array( + stats[k], dtype=np.int64 if k == "count" else np.float32 + ) if fkey == "observation.image" and k != "count": - stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels + stats[k] = stats[k].reshape( + 3, 1, 1 + ) # for normalization on image channels else: stats[k] = stats[k].reshape(1) results = aggregate_stats(all_stats) for fkey in expected_agg_stats: - np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"]) - np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"]) - np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"]) np.testing.assert_allclose( - results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04 + results[fkey]["min"], expected_agg_stats[fkey]["min"] + ) + np.testing.assert_allclose( + results[fkey]["max"], expected_agg_stats[fkey]["max"] + ) + np.testing.assert_allclose( + results[fkey]["mean"], expected_agg_stats[fkey]["mean"] + ) + np.testing.assert_allclose( + results[fkey]["std"], + expected_agg_stats[fkey]["std"], + atol=1e-04, + rtol=1e-04, + ) + np.testing.assert_allclose( + results[fkey]["count"], expected_agg_stats[fkey]["count"] ) - np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 81447089..571b1fcc 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -72,7 +72,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): # Instantiate both ways robot = make_robot("koch", mock=True) root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) + dataset_create = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create + ) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init) @@ -104,7 +106,8 @@ def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( - ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n" + ValueError, + match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n", ): dataset.add_frame({"state": torch.randn(1)}) @@ -113,7 +116,8 @@ def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( - ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n" + ValueError, + match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n", ): dataset.add_frame({"task": "Dummy task"}) @@ -122,18 +126,24 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( - ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n" + ValueError, + match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n", ): - dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) + dataset.add_frame( + {"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"} + ) def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( - ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n" + ValueError, + match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n", ): - dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) + dataset.add_frame( + {"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"} + ) def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): @@ -141,7 +151,9 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( ValueError, - match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), + match=re.escape( + "The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n" + ), ): dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) @@ -163,7 +175,9 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( ValueError, - match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), + match=re.escape( + "The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n" + ), ): dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"}) @@ -457,7 +471,9 @@ def test_flatten_unflatten_dict(): d = unflatten_dict(flatten_dict(d)) # test equality between nested dicts - assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" + assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), ( + f"{original_d} != {d}" + ) @pytest.mark.parametrize( @@ -511,7 +527,13 @@ def test_backward_compatibility(repo_id): load_and_compare(i + 1) # test 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + i = int( + ( + dataset.episode_data_index["to"][0].item() + - dataset.episode_data_index["from"][0].item() + ) + / 2 + ) load_and_compare(i) load_and_compare(i + 1) diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index 35014642..62875116 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -54,7 +54,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n @pytest.fixture(scope="module") def synced_timestamps_factory(hf_dataset_factory): - def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def _create_synced_timestamps( + fps: int = 30, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: hf_dataset = hf_dataset_factory(fps=fps) timestamps = torch.stack(hf_dataset["timestamp"]).numpy() episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() @@ -69,8 +71,12 @@ def unsynced_timestamps_factory(synced_timestamps_factory): def _create_unsynced_timestamps( fps: int = 30, tolerance_s: float = 1e-4 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance + timestamps, episode_indices, episode_data_index = synced_timestamps_factory( + fps=fps + ) + timestamps[30] += ( + tolerance_s * 1.1 + ) # Modify a single timestamp just outside tolerance return timestamps, episode_indices, episode_data_index return _create_unsynced_timestamps @@ -81,8 +87,12 @@ def slightly_off_timestamps_factory(synced_timestamps_factory): def _create_slightly_off_timestamps( fps: int = 30, tolerance_s: float = 1e-4 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance + timestamps, episode_indices, episode_data_index = synced_timestamps_factory( + fps=fps + ) + timestamps[30] += ( + tolerance_s * 0.9 + ) # Modify a single timestamp just inside tolerance return timestamps, episode_indices, episode_data_index return _create_slightly_off_timestamps @@ -91,9 +101,13 @@ def slightly_off_timestamps_factory(synced_timestamps_factory): @pytest.fixture(scope="module") def valid_delta_timestamps_factory(): def _create_valid_delta_timestamps( - fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10) + fps: int = 30, + keys: list = DUMMY_MOTOR_FEATURES, + min_max_range: tuple[int, int] = (-10, 10), ) -> dict: - delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys} + delta_timestamps = { + key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys + } return delta_timestamps return _create_valid_delta_timestamps @@ -130,7 +144,9 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory): @pytest.fixture(scope="module") def delta_indices_factory(): - def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict: + def _delta_indices( + keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10) + ) -> dict: return {key: list(range(*min_max_range)) for key in keys} return _delta_indices @@ -182,7 +198,9 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory): fps = 30 tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s) + timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory( + fps, tolerance_s + ) result = check_timestamps_sync( timestamps=timestamps, episode_indices=ep_idx, @@ -223,7 +241,9 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory): def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory): fps = 30 tolerance_s = 1e-4 - slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s) + slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory( + fps, tolerance_s + ) result = check_delta_timestamps( delta_timestamps=slightly_off_delta_timestamps, fps=fps, diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 44f38b2e..1252d6a7 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -33,7 +33,9 @@ from lerobot.scripts.visualize_image_transforms import ( save_all_transforms, save_each_transform, ) -from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR +from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ( + ARTIFACT_DIR, +) from tests.utils import require_x86_64_kernel @@ -80,7 +82,11 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( enable=True, - tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})}, + tfs={ + "brightness": ImageTransformConfig( + type="ColorJitter", kwargs={"brightness": min_max} + ) + }, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = v2.ColorJitter(brightness=min_max) @@ -91,7 +97,12 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max): def test_get_image_transforms_contrast(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( - enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})} + enable=True, + tfs={ + "contrast": ImageTransformConfig( + type="ColorJitter", kwargs={"contrast": min_max} + ) + }, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = v2.ColorJitter(contrast=min_max) @@ -103,7 +114,11 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( enable=True, - tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})}, + tfs={ + "saturation": ImageTransformConfig( + type="ColorJitter", kwargs={"saturation": min_max} + ) + }, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = v2.ColorJitter(saturation=min_max) @@ -114,7 +129,8 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max): def test_get_image_transforms_hue(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( - enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})} + enable=True, + tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = v2.ColorJitter(hue=min_max) @@ -126,7 +142,11 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( enable=True, - tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})}, + tfs={ + "sharpness": ImageTransformConfig( + type="SharpnessJitter", kwargs={"sharpness": min_max} + ) + }, ) tf_actual = ImageTransforms(tf_cfg) tf_expected = SharpnessJitter(sharpness=min_max) @@ -342,7 +362,9 @@ def test_save_all_transforms(img_tensor_factory, tmp_path): # Check if the combined transforms directory exists and contains the right files combined_transforms_dir = tmp_path / "all" - assert combined_transforms_dir.exists(), "Combined transforms directory was not created." + assert combined_transforms_dir.exists(), ( + "Combined transforms directory was not created." + ) assert any(combined_transforms_dir.iterdir()), ( "No transformed images found in combined transforms directory." ) @@ -364,9 +386,9 @@ def test_save_each_transform(img_tensor_factory, tmp_path): for transform in transforms: transform_dir = tmp_path / transform assert transform_dir.exists(), f"{transform} directory was not created." - assert any( - transform_dir.iterdir() - ), f"No transformed images found in {transform} directory." + assert any(transform_dir.iterdir()), ( + f"No transformed images found in {transform} directory." + ) # Check for specific files within each transform directory expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [ diff --git a/tests/datasets/test_online_buffer.py b/tests/datasets/test_online_buffer.py index 0285be1b..52cdf684 100644 --- a/tests/datasets/test_online_buffer.py +++ b/tests/datasets/test_online_buffer.py @@ -176,7 +176,9 @@ def test_delta_timestamps_within_tolerance(): buffer.tolerance_s = 0.04 item = buffer[2] data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"] - torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values") + torch.testing.assert_close( + data, torch.tensor([0, 2, 3]), msg="Data does not match expected values" + ) assert not is_pad.any(), "Unexpected padding detected" @@ -212,7 +214,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range(): buffer.tolerance_s = 0.04 item = buffer[2] data, is_pad = item["index"], item["index_is_pad"] - assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" + assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), ( + "Data does not match expected values" + ) assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), ( "Padding does not match expected values" ) @@ -275,7 +279,8 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p online_sampling_ratio=online_sampling_ratio, ) torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) + weights, + torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]), ) @@ -297,7 +302,8 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n( online_drop_n_last_frames=1, ) torch.testing.assert_close( - weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]) + weights, + torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]), ) @@ -318,4 +324,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp online_sampling_ratio=0.5, online_drop_n_last_frames=1, ) - torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])) + torch.testing.assert_close( + weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]) + ) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 0d02218a..0e50c73a 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -18,8 +18,13 @@ import torch from datasets import Dataset from huggingface_hub import DatasetCard -from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch +from lerobot.common.datasets.push_dataset_to_hub.utils import ( + calculate_episode_data_index, +) +from lerobot.common.datasets.utils import ( + create_lerobot_dataset_card, + hf_transform_to_torch, +) def test_default_parameters(): diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index b7b63614..26a40ad9 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -210,7 +210,10 @@ def tasks_factory(): def _create_tasks(total_tasks: int = 3) -> int: tasks = {} for task_index in range(total_tasks): - task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."} + task_dict = { + "task_index": task_index, + "task": f"Perform action {task_index}.", + } tasks[task_index] = task_dict return tasks @@ -297,8 +300,12 @@ def hf_dataset_factory( episode_index_col = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64) for ep_dict in episodes.values(): - timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) - frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) + timestamp_col = np.concatenate( + (timestamp_col, np.arange(ep_dict["length"]) / fps) + ) + frame_index_col = np.concatenate( + (frame_index_col, np.arange(ep_dict["length"], dtype=int)) + ) episode_index_col = np.concatenate( ( episode_index_col, @@ -385,7 +392,9 @@ def lerobot_dataset_metadata_factory( episodes=episodes, ) with ( - patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, + patch( + "lerobot.common.datasets.lerobot_dataset.get_safe_version" + ) as mock_get_safe_version_patch, patch( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, @@ -433,7 +442,9 @@ def lerobot_dataset_factory( if not stats: stats = stats_factory(features=info["features"]) if not episodes_stats: - episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes) + episodes_stats = episodes_stats_factory( + features=info["features"], total_episodes=total_episodes + ) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episode_dicts: @@ -466,8 +477,12 @@ def lerobot_dataset_factory( episodes=episode_dicts, ) with ( - patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, - patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, + patch( + "lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata" + ) as mock_metadata_patch, + patch( + "lerobot.common.datasets.lerobot_dataset.get_safe_version" + ) as mock_get_safe_version_patch, patch( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index d869586f..0ab33f84 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -59,7 +59,9 @@ def stats_path(stats_factory): @pytest.fixture(scope="session") def episodes_stats_path(episodes_stats_factory): - def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: + def _create_episodes_stats_jsonl_file( + dir: Path, episodes_stats: list[dict] | None = None + ) -> Path: if not episodes_stats: episodes_stats = episodes_stats_factory() fpath = dir / EPISODES_STATS_PATH diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 0bf4cc69..cb2195b1 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -99,7 +99,13 @@ def mock_snapshot_download_factory( # List all possible files all_files = [] - meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH] + meta_files = [ + INFO_PATH, + STATS_PATH, + EPISODES_STATS_PATH, + TASKS_PATH, + EPISODES_PATH, + ] all_files.extend(meta_files) data_files = [] diff --git a/tests/fixtures/optimizers.py b/tests/fixtures/optimizers.py index 65488566..149f0255 100644 --- a/tests/fixtures/optimizers.py +++ b/tests/fixtures/optimizers.py @@ -35,5 +35,7 @@ def optimizer(model_params): @pytest.fixture def scheduler(optimizer): - config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) + config = VQBeTSchedulerConfig( + num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5 + ) return config.build(optimizer, num_training_steps=100) diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 17637663..e51558cd 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -43,7 +43,9 @@ def test_diffuser_scheduler(optimizer): def test_vqbet_scheduler(optimizer): - config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) + config = VQBeTSchedulerConfig( + num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5 + ) scheduler = config.build(optimizer, num_training_steps=100) assert isinstance(scheduler, LambdaLR) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index b3df477a..eb1c0003 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -59,16 +59,33 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p "action": { "dtype": "float32", "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], + "names": [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ], }, "observation.state": { "dtype": "float32", "shape": (6,), - "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], + "names": [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", + ], }, } info = info_factory( - total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features + total_episodes=1, + total_frames=1, + camera_features=camera_features, + motor_features=motor_features, ) ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info) return ds_meta @@ -81,7 +98,8 @@ def test_get_policy_and_config_classes(policy_name: str): policy_cfg = make_policy_config(policy_name) assert policy_cls.name == policy_name assert issubclass( - policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation + policy_cfg.__class__, + inspect.signature(policy_cls.__init__).parameters["config"].annotation, ) @@ -92,7 +110,13 @@ def test_get_policy_and_config_classes(policy_name: str): ("lerobot/pusht", "pusht", {}, "diffusion", {}), ("lerobot/pusht", "pusht", {}, "vqbet", {}), ("lerobot/pusht", "pusht", {}, "act", {}), - ("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}), + ( + "lerobot/aloha_sim_insertion_human", + "aloha", + {"task": "AlohaInsertion-v0"}, + "act", + {}, + ), ( "lerobot/aloha_sim_insertion_scripted", "aloha", @@ -172,11 +196,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # Test updating the policy (and test that it does not mutate the batch) batch_ = deepcopy(batch) policy.forward(batch) - assert set(batch) == set( - batch_ - ), "Batch keys are not the same after a forward pass." + assert set(batch) == set(batch_), ( + "Batch keys are not the same after a forward pass." + ) assert all( - torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k] + torch.equal(batch[k], batch_[k]) + if isinstance(batch[k], torch.Tensor) + else batch[k] == batch_[k] for k in batch ), "Batch values are not the same after a forward pass." @@ -215,8 +241,12 @@ def test_act_backbone_lr(): cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download - dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), - policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001), + dataset=DatasetConfig( + repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0] + ), + policy=make_policy_config( + "act", optimizer_lr=0.01, optimizer_lr_backbone=0.001 + ), ) cfg.validate() # Needed for auto-setting some parameters @@ -239,7 +269,9 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str): policy_cls = get_policy_class(policy_name) policy_cfg = make_policy_config(policy_name) features = dataset_to_policy_features(dummy_dataset_metadata.features) - policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + policy_cfg.output_features = { + key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION + } policy_cfg.input_features = { key: ft for key, ft in features.items() if key not in policy_cfg.output_features } @@ -251,7 +283,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: policy_cls = get_policy_class(policy_name) policy_cfg = make_policy_config(policy_name) features = dataset_to_policy_features(dummy_dataset_metadata.features) - policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + policy_cfg.output_features = { + key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION + } policy_cfg.input_features = { key: ft for key, ft in features.items() if key not in policy_cfg.output_features } @@ -260,7 +294,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}" policy.save_pretrained(save_dir) loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg) - torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) + torch.testing.assert_close( + list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0 + ) @pytest.mark.parametrize("insert_temporal_dim", [False, True]) @@ -400,7 +436,9 @@ def test_normalize(insert_temporal_dim): # pass if it's run on another platform due to floating point errors @require_x86_64_kernel @require_cpu -def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str): +def test_backward_compatibility( + ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str +): """ NOTE: If this test does not pass, and you have intentionally changed something in the policy: 1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should @@ -414,13 +452,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs 6. Remember to stage and commit the resulting changes to `tests/artifacts`. """ ds_name = ds_repo_id.split("/")[-1] - artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}" + artifact_dir = ( + Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}" + ) saved_output_dict = load_file(artifact_dir / "output_dict.safetensors") saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors") saved_param_stats = load_file(artifact_dir / "param_stats.safetensors") saved_actions = load_file(artifact_dir / "actions.safetensors") - output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs) + output_dict, grad_stats, param_stats, actions = get_policy_stats( + ds_repo_id, policy_name, policy_kwargs + ) for key in saved_output_dict: torch.testing.assert_close(output_dict[key], saved_output_dict[key]) @@ -429,8 +471,12 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs for key in saved_param_stats: torch.testing.assert_close(param_stats[key], saved_param_stats[key]) for key in saved_actions: - rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK - torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol) + rtol, atol = ( + (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) + ) # HACK + torch.testing.assert_close( + actions[key], saved_actions[key], rtol=rtol, atol=atol + ) def test_act_temporal_ensembler(): diff --git a/tests/robots/test_control_robot.py b/tests/robots/test_control_robot.py index 059bb79f..6878a95a 100644 --- a/tests/robots/test_control_robot.py +++ b/tests/robots/test_control_robot.py @@ -180,7 +180,9 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock): assert dataset.meta.total_episodes == 2 assert len(dataset) == 2 - replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False) + replay_cfg = ReplayControlConfig( + episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False + ) replay(robot, replay_cfg) policy_cfg = ACTConfig() @@ -335,12 +337,12 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock) ) dataset = record(robot, rec_cfg) - assert not mock_events[ - "rerecord_episode" - ], "`rerecord_episode` wasn't properly reset to False" - assert not mock_events[ - "exit_early" - ], "`exit_early` wasn't properly reset to False" + assert not mock_events["rerecord_episode"], ( + "`rerecord_episode` wasn't properly reset to False" + ) + assert not mock_events["exit_early"], ( + "`exit_early` wasn't properly reset to False" + ) assert len(dataset) == 1, "`dataset` should contain only 1 frame" @@ -390,7 +392,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): dataset = record(robot, rec_cfg) - assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" + assert not mock_events["exit_early"], ( + "`exit_early` wasn't properly reset to False" + ) assert len(dataset) == 1, "`dataset` should contain only 1 frame" @@ -399,7 +403,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): [("koch", True, 0), ("koch", True, 1)], ) @require_robot -def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes): +def test_record_with_event_stop_recording( + tmp_path, request, robot_type, mock, num_image_writer_processes +): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: @@ -445,5 +451,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n dataset = record(robot, rec_cfg) - assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False" + assert not mock_events["exit_early"], ( + "`exit_early` wasn't properly reset to False" + ) assert len(dataset) == 1, "`dataset` should contain only 1 frame" diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py index 4616c747..ba676bc3 100644 --- a/tests/robots/test_robots.py +++ b/tests/robots/test_robots.py @@ -40,7 +40,10 @@ import pytest import torch from lerobot.common.robot_devices.robots.utils import make_robot -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, +) from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot @@ -131,7 +134,9 @@ def test_robot(tmp_path, request, robot_type, mock): if "image" in name: # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames continue - torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1) + torch.testing.assert_close( + captured_observation[name], observation[name], rtol=1e-4, atol=1 + ) assert captured_observation[name].shape == observation[name].shape # Test send_action can run diff --git a/tests/test_train_hilserl_classifier.py b/tests/test_train_hilserl_classifier.py index bc7a18bc..837eebc8 100644 --- a/tests/test_train_hilserl_classifier.py +++ b/tests/test_train_hilserl_classifier.py @@ -227,9 +227,9 @@ def test_resume_function( config_dir = os.path.abspath( os.path.join(test_file_dir, "..", "lerobot", "configs", "policy") ) - assert os.path.exists( - config_dir - ), f"Config directory does not exist at {config_dir}" + assert os.path.exists(config_dir), ( + f"Config directory does not exist at {config_dir}" + ) with initialize_config_dir( config_dir=config_dir, job_name="test_app", version_base="1.2" diff --git a/tests/utils.py b/tests/utils.py index ca4d89bf..23b297cb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,10 +26,16 @@ from lerobot import available_cameras, available_motors, available_robots from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device from lerobot.common.robot_devices.motors.utils import MotorsBus -from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device +from lerobot.common.robot_devices.motors.utils import ( + make_motors_bus as make_motors_bus_device, +) from lerobot.common.utils.import_utils import is_package_available -DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" +DEVICE = ( + os.environ.get("LEROBOT_TEST_DEVICE", "cuda") + if torch.cuda.is_available() + else "cpu" +) TEST_ROBOT_TYPES = [] for robot_type in available_robots: @@ -45,7 +51,9 @@ for motor_type in available_motors: # Camera indices used for connecting physical cameras OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) -INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)) +INTELREALSENSE_SERIAL_NUMBER = int( + os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614) +) DYNAMIXEL_PORT = os.environ.get( "LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081" diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py index 1ba1829e..f57a9eb7 100644 --- a/tests/utils/test_logging_utils.py +++ b/tests/utils/test_logging_utils.py @@ -18,7 +18,10 @@ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker @pytest.fixture def mock_metrics(): - return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")} + return { + "loss": AverageMeter("loss", ":.3f"), + "accuracy": AverageMeter("accuracy", ":.2f"), + } def test_average_meter_initialization(): @@ -58,7 +61,11 @@ def test_average_meter_str(): def test_metrics_tracker_initialization(mock_metrics): tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10 + batch_size=32, + num_frames=1000, + num_episodes=50, + metrics=mock_metrics, + initial_step=10, ) assert tracker.steps == 10 assert tracker.samples == 10 * 32 @@ -70,7 +77,11 @@ def test_metrics_tracker_initialization(mock_metrics): def test_metrics_tracker_step(mock_metrics): tracker = MetricsTracker( - batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5 + batch_size=32, + num_frames=1000, + num_episodes=50, + metrics=mock_metrics, + initial_step=5, ) tracker.step() assert tracker.steps == 6 @@ -80,7 +91,9 @@ def test_metrics_tracker_step(mock_metrics): def test_metrics_tracker_getattr(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker = MetricsTracker( + batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics + ) assert tracker.loss == mock_metrics["loss"] assert tracker.accuracy == mock_metrics["accuracy"] with pytest.raises(AttributeError): @@ -88,13 +101,17 @@ def test_metrics_tracker_getattr(mock_metrics): def test_metrics_tracker_setattr(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker = MetricsTracker( + batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics + ) tracker.loss = 2.0 assert tracker.loss.val == 2.0 def test_metrics_tracker_str(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker = MetricsTracker( + batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics + ) tracker.loss.update(3.456, 1) tracker.accuracy.update(0.876, 1) output = str(tracker) @@ -103,7 +120,9 @@ def test_metrics_tracker_str(mock_metrics): def test_metrics_tracker_to_dict(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker = MetricsTracker( + batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics + ) tracker.loss.update(5, 2) metrics_dict = tracker.to_dict() assert isinstance(metrics_dict, dict) @@ -112,7 +131,9 @@ def test_metrics_tracker_to_dict(mock_metrics): def test_metrics_tracker_reset_averages(mock_metrics): - tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker = MetricsTracker( + batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics + ) tracker.loss.update(10, 3) tracker.accuracy.update(0.95, 5) tracker.reset_averages() diff --git a/tests/utils/test_random_utils.py b/tests/utils/test_random_utils.py index daf08a89..01df7341 100644 --- a/tests/utils/test_random_utils.py +++ b/tests/utils/test_random_utils.py @@ -118,5 +118,9 @@ def test_seeded_context(fixed_seed): seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item()) assert seeded_val1 == seeded_val2 - assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context - assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting + assert all( + a != b for a, b in zip(val1, seeded_val1, strict=True) + ) # changed inside the context + assert all( + a != b for a, b in zip(val2, seeded_val2, strict=True) + ) # changed again after exiting diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index b78f6e49..1fc63eef 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -91,7 +91,9 @@ def test_save_training_state(tmp_path, optimizer, scheduler): def test_save_load_training_state(tmp_path, optimizer, scheduler): save_training_state(tmp_path, 10, optimizer, scheduler) - loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler) + loaded_step, loaded_optimizer, loaded_scheduler = load_training_state( + tmp_path, optimizer, scheduler + ) assert loaded_step == 10 assert loaded_optimizer is optimizer assert loaded_scheduler is scheduler