diff --git a/.github/workflows/build-docker-images.yml b/.github/workflows/build-docker-images.yml index f20de978..3c63fa11 100644 --- a/.github/workflows/build-docker-images.yml +++ b/.github/workflows/build-docker-images.yml @@ -8,6 +8,8 @@ on: schedule: - cron: "0 1 * * *" +permissions: {} + env: PYTHON_VERSION: "3.10" @@ -25,11 +27,14 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + cache-binary: false - name: Check out code uses: actions/checkout@v4 with: lfs: true + persist-credentials: false - name: Login to DockerHub uses: docker/login-action@v3 @@ -60,11 +65,14 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + cache-binary: false - name: Check out code uses: actions/checkout@v4 with: lfs: true + persist-credentials: false - name: Login to DockerHub uses: docker/login-action@v3 @@ -89,9 +97,13 @@ jobs: steps: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + cache-binary: false - name: Check out code uses: actions/checkout@v4 + with: + persist-credentials: false - name: Login to DockerHub uses: docker/login-action@v3 diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml index bbee19a1..210a690c 100644 --- a/.github/workflows/nightly-tests.yml +++ b/.github/workflows/nightly-tests.yml @@ -7,6 +7,8 @@ on: schedule: - cron: "0 2 * * *" +permissions: {} + # env: # SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }} jobs: diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 6acc901e..f26fc1ed 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -8,6 +8,8 @@ on: branches: - main +permissions: {} + env: PYTHON_VERSION: "3.10" @@ -17,7 +19,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python uses: actions/setup-python@v4 @@ -45,7 +49,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 + with: + persist-credentials: false - name: Install poetry run: pipx install "poetry<2.0.0" @@ -59,7 +65,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 + with: + persist-credentials: false - name: Install poetry run: pipx install "poetry<2.0.0" diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml index 08a3ab08..0d95427f 100644 --- a/.github/workflows/test-docker-build.yml +++ b/.github/workflows/test-docker-build.yml @@ -8,6 +8,8 @@ on: # Run only when DockerFile files are modified - "docker/**" +permissions: {} + env: PYTHON_VERSION: "3.10" @@ -20,6 +22,8 @@ jobs: steps: - name: Check out code uses: actions/checkout@v4 + with: + persist-credentials: false - name: Get changed files id: changed-files @@ -34,7 +38,7 @@ jobs: env: ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }} run: | - echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT + echo "matrix=${ALL_CHANGED_FILES}" >> $GITHUB_OUTPUT build_modified_dockerfiles: @@ -50,9 +54,13 @@ jobs: steps: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + cache-binary: false - name: Check out code uses: actions/checkout@v4 + with: + persist-credentials: false - name: Build Docker image uses: docker/build-push-action@v5 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bc567418..1b0853ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,6 +22,8 @@ on: - "Makefile" - ".cache/**" +permissions: {} + jobs: pytest: name: Pytest @@ -32,6 +34,7 @@ jobs: - uses: actions/checkout@v4 with: lfs: true # Ensure LFS files are pulled + persist-credentials: false - name: Install apt dependencies # portaudio19-dev is needed to install pyaudio @@ -72,6 +75,7 @@ jobs: - uses: actions/checkout@v4 with: lfs: true # Ensure LFS files are pulled + persist-credentials: false - name: Install apt dependencies run: sudo apt-get update && sudo apt-get install -y ffmpeg @@ -108,6 +112,7 @@ jobs: - uses: actions/checkout@v4 with: lfs: true # Ensure LFS files are pulled + persist-credentials: false - name: Install apt dependencies # portaudio19-dev is needed to install pyaudio diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index d1dddab7..487ccea5 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -3,8 +3,7 @@ on: name: Secret Leaks -permissions: - contents: read +permissions: {} jobs: trufflehog: @@ -14,6 +13,8 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false + - name: Secret Scanning uses: trufflesecurity/trufflehog@main with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 58eca320..e5d45e81 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,17 +14,17 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/asottile/pyupgrade - rev: v3.19.0 + rev: v3.19.1 hooks: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.2 + rev: v0.9.6 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/python-poetry/poetry - rev: 1.8.0 + rev: 1.8.5 hooks: - id: poetry-check - id: poetry-lock @@ -32,6 +32,10 @@ repos: - "--check" - "--no-update" - repo: https://github.com/gitleaks/gitleaks - rev: v8.21.2 + rev: v8.23.3 hooks: - id: gitleaks + - repo: https://github.com/woodruffw/zizmor-pre-commit + rev: v1.3.1 + hooks: + - id: zizmor diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 58ff400e..95ba76b8 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -104,7 +104,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas ) logging.info( "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " - f"{pformat(dataset.repo_id_to_index , indent=2)}" + f"{pformat(dataset.repo_id_to_index, indent=2)}" ) if cfg.dataset.use_imagenet_stats: diff --git a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py index 95f9c007..4968e002 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py @@ -72,7 +72,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps. # are too far appart. direction="nearest", - tolerance=pd.Timedelta(f"{1/fps} seconds"), + tolerance=pd.Timedelta(f"{1 / fps} seconds"), ) # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes df = df[df["episode_index"] != -1] diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 9a1036c3..f2b16a1e 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -409,9 +409,9 @@ class ACT(nn.Module): latent dimension. """ if self.config.use_vae and self.training: - assert ( - "action" in batch - ), "actions must be provided when using the variational objective in training mode." + assert "action" in batch, ( + "actions must be provided when using the variational objective in training mode." + ) batch_size = ( batch["observation.images"] diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 31d5dc8b..d571e152 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -221,7 +221,7 @@ class DiffusionConfig(PreTrainedConfig): for key, image_ft in self.image_features.items(): if image_ft.shape != first_image_ft.shape: raise ValueError( - f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match." + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." ) @property diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index 90d1a14c..c8b12caf 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy): self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: """Do a full training forward pass to compute the loss""" if self.config.adapt_to_pi_aloha: batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) @@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy): losses = losses[:, :, : self.config.max_action_dim] loss_dict["losses_after_rm_padding"] = losses.clone() - loss = losses.mean() # For backward pass - loss_dict["loss"] = loss + loss = losses.mean() # For logging loss_dict["l2_loss"] = loss.item() - return loss_dict + + return loss, loss_dict def prepare_images(self, batch): """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index c4f90b8d..0940f198 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -594,9 +594,9 @@ class TDMPCTOLD(nn.Module): self.apply(_apply_fn) for m in [self._reward, *self._Qs]: - assert isinstance( - m[-1], nn.Linear - ), "Sanity check. The last linear layer needs 0 initialization on weights." + assert isinstance(m[-1], nn.Linear), ( + "Sanity check. The last linear layer needs 0 initialization on weights." + ) nn.init.zeros_(m[-1].weight) nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index 47007e82..59389d6e 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -184,7 +184,7 @@ class VQBeTConfig(PreTrainedConfig): for key, image_ft in self.image_features.items(): if image_ft.shape != first_image_ft.shape: raise ValueError( - f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match." + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." ) @property diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py index 90a2cfda..a2bd2df3 100644 --- a/lerobot/common/policies/vqbet/vqbet_utils.py +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -203,9 +203,9 @@ class GPT(nn.Module): def forward(self, input, targets=None): device = input.device b, t, d = input.size() - assert ( - t <= self.config.gpt_block_size - ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}" + assert t <= self.config.gpt_block_size, ( + f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}" + ) # positional encodings that are added to the input embeddings pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) @@ -273,10 +273,10 @@ class GPT(nn.Module): assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( str(inter_params) ) - assert ( - len(param_dict.keys() - union_params) == 0 - ), "parameters {} were not separated into either decay/no_decay set!".format( - str(param_dict.keys() - union_params), + assert len(param_dict.keys() - union_params) == 0, ( + "parameters {} were not separated into either decay/no_decay set!".format( + str(param_dict.keys() - union_params), + ) ) decay = [param_dict[pn] for pn in sorted(decay)] @@ -419,9 +419,9 @@ class ResidualVQ(nn.Module): # and the network should be able to reconstruct if quantize_dim < self.num_quantizers: - assert ( - self.quantize_dropout > 0.0 - ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations" + assert self.quantize_dropout > 0.0, ( + "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations" + ) indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1) # get ready for gathering @@ -472,9 +472,9 @@ class ResidualVQ(nn.Module): all_indices = [] if return_loss: - assert not torch.any( - indices == -1 - ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss" + assert not torch.any(indices == -1), ( + "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss" + ) ce_losses = [] should_quantize_dropout = self.training and self.quantize_dropout and not return_loss @@ -887,9 +887,9 @@ class VectorQuantize(nn.Module): # only calculate orthogonal loss for the activated codes for this batch if self.orthogonal_reg_active_codes_only: - assert not ( - is_multiheaded and self.separate_codebook_per_head - ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet" + assert not (is_multiheaded and self.separate_codebook_per_head), ( + "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet" + ) unique_code_ids = torch.unique(embed_ind) codebook = codebook[:, unique_code_ids] @@ -999,9 +999,9 @@ def gumbel_sample( ind = sampling_logits.argmax(dim=dim) one_hot = F.one_hot(ind, size).type(dtype) - assert not ( - reinmax and not straight_through - ), "reinmax can only be turned on if using straight through gumbel softmax" + assert not (reinmax and not straight_through), ( + "reinmax can only be turned on if using straight through gumbel softmax" + ) if not straight_through or temperature <= 0.0 or not training: return ind, one_hot @@ -1209,9 +1209,9 @@ class EuclideanCodebook(nn.Module): self.gumbel_sample = gumbel_sample self.sample_codebook_temp = sample_codebook_temp - assert not ( - use_ddp and num_codebooks > 1 and kmeans_init - ), "kmeans init is not compatible with multiple codebooks in distributed environment for now" + assert not (use_ddp and num_codebooks > 1 and kmeans_init), ( + "kmeans init is not compatible with multiple codebooks in distributed environment for now" + ) self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 5dcafa69..6c97d0cb 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -33,7 +33,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f def log_dt(shortname, dt_val_s): nonlocal log_items, fps - info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)" + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" if fps is not None: actual_fps = 1 / dt_val_s if actual_fps < fps - 1: diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index 3fc405f7..da0be1c7 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -58,7 +58,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T: # Check that they have exactly the same set of keys. if target.keys() != source.keys(): raise ValueError( - f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}" + f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}" ) # Recursively update each key. diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index 2ab3c3fd..9985b894 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -102,7 +102,7 @@ class WandBLogger: self._wandb.log_artifact(artifact) def log_dict(self, d: dict, step: int, mode: str = "train"): - if mode in {"train", "eval"}: + if mode not in {"train", "eval"}: raise ValueError(mode) for k, v in d.items(): @@ -114,7 +114,7 @@ class WandBLogger: self._wandb.log({f"{mode}/{k}": v}, step=step) def log_video(self, video_path: str, step: int, mode: str = "train"): - if mode in {"train", "eval"}: + if mode not in {"train", "eval"}: raise ValueError(mode) wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4") diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 9d63339d..93f6e2a4 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -85,6 +85,11 @@ class TrainPipelineConfig(HubMixin): config_path = parser.parse_arg("config_path") if not config_path: raise ValueError("A config_path is expected when resuming a run.") + if not Path(config_path).resolve().exists(): + raise NotADirectoryError( + f"{config_path=} is expected to be a local path. " + "Resuming from the hub is not supported for now." + ) policy_path = Path(config_path).parent self.policy.pretrained_path = policy_path self.checkpoint_path = policy_path.parent diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 7318748f..a4f79afc 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -151,7 +151,9 @@ def rollout( if return_observations: all_observations.append(deepcopy(observation)) - observation = {key: observation[key].to(device, non_blocking=True) for key in observation} + observation = { + key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation + } with torch.inference_mode(): action = policy.select_action(observation) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index a840b33d..f3c57fe2 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -232,8 +232,10 @@ def train(cfg: TrainPipelineConfig): if is_log_step: logging.info(train_tracker) if wandb_logger: - wandb_log_dict = {**train_tracker.to_dict(), **output_dict} - wandb_logger.log_dict(wandb_log_dict) + wandb_log_dict = train_tracker.to_dict() + if output_dict: + wandb_log_dict.update(output_dict) + wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() if cfg.save_checkpoint and is_saving_step: @@ -271,6 +273,7 @@ def train(cfg: TrainPipelineConfig): logging.info(eval_tracker) if wandb_logger: wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_dict(wandb_log_dict, step, mode="eval") wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") if eval_env: diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index ca176407..626b0bde 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -111,9 +111,9 @@ def visualize_dataset( output_dir: Path | None = None, ) -> Path | None: if save: - assert ( - output_dir is not None - ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." + assert output_dir is not None, ( + "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." + ) repo_id = dataset.repo_id diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py index 84c8f169..3b77348c 100644 --- a/tests/scripts/save_dataset_to_safetensors.py +++ b/tests/scripts/save_dataset_to_safetensors.py @@ -49,17 +49,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): # save 2 first frames of first episode i = dataset.episode_data_index["from"][0].item() save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") - save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors") + save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 frames at the middle of first episode i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") - save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors") + save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 last frames of first episode i = dataset.episode_data_index["to"][0].item() - save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors") - save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors") + save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors") + save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors") # TODO(rcadene): Enable testing on second and last episode # We currently cant because our test dataset only contains the first episode diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b1df9b46..72fa5b50 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -519,9 +519,9 @@ def test_backward_compatibility(repo_id): assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same" for key in new_frame: - assert torch.isclose( - new_frame[key], old_frame[key] - ).all(), f"{key=} for index={i} does not contain the same value" + assert torch.isclose(new_frame[key], old_frame[key]).all(), ( + f"{key=} for index={i} does not contain the same value" + ) # test2 first frames of first episode i = dataset.episode_data_index["from"][0].item() diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index c118018a..19bd77df 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path): # Check if the combined transforms directory exists and contains the right files combined_transforms_dir = tmp_path / "all" assert combined_transforms_dir.exists(), "Combined transforms directory was not created." - assert any( - combined_transforms_dir.iterdir() - ), "No transformed images found in combined transforms directory." + assert any(combined_transforms_dir.iterdir()), ( + "No transformed images found in combined transforms directory." + ) for i in range(1, n_examples + 1): - assert ( - combined_transforms_dir / f"{i}.png" - ).exists(), f"Combined transform image {i}.png was not found." + assert (combined_transforms_dir / f"{i}.png").exists(), ( + f"Combined transform image {i}.png was not found." + ) def test_save_each_transform(img_tensor_factory, tmp_path): @@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path): # Check for specific files within each transform directory expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"] for file_name in expected_files: - assert ( - transform_dir / file_name - ).exists(), f"{file_name} was not found in {transform} directory." + assert (transform_dir / file_name).exists(), ( + f"{file_name} was not found in {transform} directory." + ) diff --git a/tests/test_online_buffer.py b/tests/test_online_buffer.py index 092cd3d0..db53808d 100644 --- a/tests/test_online_buffer.py +++ b/tests/test_online_buffer.py @@ -132,9 +132,9 @@ def test_fifo(): buffer.add_data(new_data) n_more_episodes = 2 # Developer sanity check (in case someone changes the global `buffer_capacity`). - assert ( - n_episodes + n_more_episodes - ) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code." + assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, ( + "Something went wrong with the test code." + ) more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode) buffer.add_data(more_new_data) assert len(buffer) == buffer_capacity, "The buffer should be full." @@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range(): item = buffer[2] data, is_pad = item["index"], item["index_is_pad"] assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" - assert torch.equal( - is_pad, torch.tensor([True, False, False, True, True]) - ), "Padding does not match expected values" + assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), ( + "Padding does not match expected values" + ) # Arbitrarily set small dataset sizes, making sure to have uneven sizes. diff --git a/tests/test_policies.py b/tests/test_policies.py index 4374157d..27cf49f8 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): observation_ = deepcopy(observation) with torch.inference_mode(): action = policy.select_action(observation).cpu().numpy() - assert set(observation) == set( - observation_ - ), "Observation batch keys are not the same after a forward pass." - assert all( - torch.equal(observation[k], observation_[k]) for k in observation - ), "Observation batch values are not the same after a forward pass." + assert set(observation) == set(observation_), ( + "Observation batch keys are not the same after a forward pass." + ) + assert all(torch.equal(observation[k], observation_[k]) for k in observation), ( + "Observation batch values are not the same after a forward pass." + ) # Test step through policy env.step(action)