Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1

Simon Alibert 2025-02-15 15:56:24 +01:00
commit aed3eb4a94
27 changed files with 134 additions and 84 deletions

View File

@@ -8,6 +8,8 @@ on:
   schedule:
     - cron: "0 1 * * *"
+permissions: {}
 env:
   PYTHON_VERSION: "3.10"
@@ -25,11 +27,14 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
         with:
           lfs: true
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -60,11 +65,14 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
         with:
           lfs: true
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -89,9 +97,13 @@ jobs:
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3

View File

@@ -7,6 +7,8 @@ on:
   schedule:
     - cron: "0 2 * * *"
+permissions: {}
 # env:
 #   SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
 jobs:

View File

@@ -8,6 +8,8 @@ on:
     branches:
       - main
+permissions: {}
 env:
   PYTHON_VERSION: "3.10"
@@ -17,7 +19,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Set up Python
         uses: actions/setup-python@v4
@@ -45,7 +49,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Install poetry
         run: pipx install "poetry<2.0.0"
@@ -59,7 +65,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Checkout Repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
      - name: Install poetry
        run: pipx install "poetry<2.0.0"

View File

@@ -8,6 +8,8 @@ on:
       # Run only when DockerFile files are modified
       - "docker/**"
+permissions: {}
 env:
   PYTHON_VERSION: "3.10"
@@ -20,6 +22,8 @@ jobs:
     steps:
       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Get changed files
         id: changed-files
@@ -34,7 +38,7 @@ jobs:
         env:
           ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
         run: |
-          echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
+          echo "matrix=${ALL_CHANGED_FILES}" >> $GITHUB_OUTPUT
   build_modified_dockerfiles:
@@ -50,9 +54,13 @@ jobs:
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Build Docker image
         uses: docker/build-push-action@v5

View File

@@ -22,6 +22,8 @@ on:
       - "Makefile"
       - ".cache/**"
+permissions: {}
 jobs:
   pytest:
     name: Pytest
@@ -32,6 +34,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           lfs: true  # Ensure LFS files are pulled
+          persist-credentials: false
       - name: Install apt dependencies
         # portaudio19-dev is needed to install pyaudio
@@ -72,6 +75,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           lfs: true  # Ensure LFS files are pulled
+          persist-credentials: false
       - name: Install apt dependencies
         run: sudo apt-get update && sudo apt-get install -y ffmpeg
@@ -108,6 +112,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           lfs: true  # Ensure LFS files are pulled
+          persist-credentials: false
       - name: Install apt dependencies
         # portaudio19-dev is needed to install pyaudio

View File

@@ -3,8 +3,7 @@ on:
 name: Secret Leaks
-permissions:
-  contents: read
+permissions: {}
 jobs:
   trufflehog:
@@ -14,6 +13,8 @@ jobs:
         uses: actions/checkout@v4
         with:
           fetch-depth: 0
+          persist-credentials: false
       - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main
         with:

View File

@@ -14,17 +14,17 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.2
+    rev: v0.9.6
     hooks:
       - id: ruff
        args: [--fix]
       - id: ruff-format
   - repo: https://github.com/python-poetry/poetry
-    rev: 1.8.0
+    rev: 1.8.5
     hooks:
       - id: poetry-check
       - id: poetry-lock
@@ -32,6 +32,10 @@ repos:
           - "--check"
           - "--no-update"
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.21.2
+    rev: v8.23.3
     hooks:
       - id: gitleaks
+  - repo: https://github.com/woodruffw/zizmor-pre-commit
+    rev: v1.3.1
+    hooks:
+      - id: zizmor

View File

@@ -409,9 +409,9 @@ class ACT(nn.Module):
             latent dimension.
         """
         if self.config.use_vae and self.training:
-            assert (
-                "action" in batch
-            ), "actions must be provided when using the variational objective in training mode."
+            assert "action" in batch, (
+                "actions must be provided when using the variational objective in training mode."
+            )
             batch_size = (
                 batch["observation.images"]

View File

@@ -221,7 +221,7 @@ class DiffusionConfig(PreTrainedConfig):
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )
     @property

View File

@@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy):
         self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
+    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
         """Do a full training forward pass to compute the loss"""
         if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy):
             losses = losses[:, :, : self.config.max_action_dim]
             loss_dict["losses_after_rm_padding"] = losses.clone()
-        loss = losses.mean()
         # For backward pass
-        loss_dict["loss"] = loss
+        loss = losses.mean()
         # For logging
         loss_dict["l2_loss"] = loss.item()
-        return loss_dict
+        return loss, loss_dict
     def prepare_images(self, batch):
         """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and

View File

@@ -594,9 +594,9 @@ class TDMPCTOLD(nn.Module):
         self.apply(_apply_fn)
         for m in [self._reward, *self._Qs]:
-            assert isinstance(
-                m[-1], nn.Linear
-            ), "Sanity check. The last linear layer needs 0 initialization on weights."
+            assert isinstance(m[-1], nn.Linear), (
+                "Sanity check. The last linear layer needs 0 initialization on weights."
+            )
             nn.init.zeros_(m[-1].weight)
             nn.init.zeros_(m[-1].bias)  # this has already been done, but keep this line here for good measure

View File

@@ -184,7 +184,7 @@ class VQBeTConfig(PreTrainedConfig):
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )
     @property

View File

@@ -203,9 +203,9 @@ class GPT(nn.Module):
     def forward(self, input, targets=None):
         device = input.device
         b, t, d = input.size()
-        assert (
-            t <= self.config.gpt_block_size
-        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        assert t <= self.config.gpt_block_size, (
+            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        )
         # positional encodings that are added to the input embeddings
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
@@ -273,11 +273,11 @@ class GPT(nn.Module):
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
             str(inter_params)
         )
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), "parameters {} were not separated into either decay/no_decay set!".format(
-            str(param_dict.keys() - union_params),
-        )
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters {} were not separated into either decay/no_decay set!".format(
+                str(param_dict.keys() - union_params),
+            )
+        )
         decay = [param_dict[pn] for pn in sorted(decay)]
         no_decay = [param_dict[pn] for pn in sorted(no_decay)]
@@ -419,9 +419,9 @@ class ResidualVQ(nn.Module):
         # and the network should be able to reconstruct
         if quantize_dim < self.num_quantizers:
-            assert (
-                self.quantize_dropout > 0.0
-            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            assert self.quantize_dropout > 0.0, (
+                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            )
             indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
         # get ready for gathering
@@ -472,9 +472,9 @@ class ResidualVQ(nn.Module):
         all_indices = []
         if return_loss:
-            assert not torch.any(
-                indices == -1
-            ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            assert not torch.any(indices == -1), (
+                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            )
             ce_losses = []
         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
@@ -887,9 +887,9 @@ class VectorQuantize(nn.Module):
         # only calculate orthogonal loss for the activated codes for this batch
         if self.orthogonal_reg_active_codes_only:
-            assert not (
-                is_multiheaded and self.separate_codebook_per_head
-            ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+            assert not (is_multiheaded and self.separate_codebook_per_head), (
+                "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+            )
             unique_code_ids = torch.unique(embed_ind)
             codebook = codebook[:, unique_code_ids]
@@ -999,9 +999,9 @@ def gumbel_sample(
         ind = sampling_logits.argmax(dim=dim)
     one_hot = F.one_hot(ind, size).type(dtype)
-    assert not (
-        reinmax and not straight_through
-    ), "reinmax can only be turned on if using straight through gumbel softmax"
+    assert not (reinmax and not straight_through), (
+        "reinmax can only be turned on if using straight through gumbel softmax"
+    )
     if not straight_through or temperature <= 0.0 or not training:
         return ind, one_hot
@@ -1209,9 +1209,9 @@ class EuclideanCodebook(nn.Module):
         self.gumbel_sample = gumbel_sample
         self.sample_codebook_temp = sample_codebook_temp
-        assert not (
-            use_ddp and num_codebooks > 1 and kmeans_init
-        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
+            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        )
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop

View File

@@ -58,7 +58,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
         # Check that they have exactly the same set of keys.
         if target.keys() != source.keys():
             raise ValueError(
-                f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
+                f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
             )
         # Recursively update each key.

View File

@@ -102,7 +102,7 @@ class WandBLogger:
         self._wandb.log_artifact(artifact)
     def log_dict(self, d: dict, step: int, mode: str = "train"):
-        if mode in {"train", "eval"}:
+        if mode not in {"train", "eval"}:
             raise ValueError(mode)
         for k, v in d.items():
@@ -114,7 +114,7 @@ class WandBLogger:
             self._wandb.log({f"{mode}/{k}": v}, step=step)
     def log_video(self, video_path: str, step: int, mode: str = "train"):
-        if mode in {"train", "eval"}:
+        if mode not in {"train", "eval"}:
             raise ValueError(mode)
         wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
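This is a behavior fix rather than a reformat: the old guard raised ValueError for the two modes that are actually valid and silently accepted everything else. A small self-contained sketch of the corrected check; the plain log_dict function below is illustrative and does not depend on wandb.

def log_dict(d: dict, step: int, mode: str = "train") -> None:
    if mode not in {"train", "eval"}:  # reject unknown modes, accept the valid ones
        raise ValueError(f"Unexpected mode: {mode!r}")
    for k, v in d.items():
        print(f"{mode}/{k} @ step {step}: {v}")

log_dict({"loss": 0.42}, step=100)              # ok
# log_dict({"loss": 0.42}, step=100, mode="x")  # would raise ValueError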

View File

@@ -85,6 +85,11 @@ class TrainPipelineConfig(HubMixin):
             config_path = parser.parse_arg("config_path")
             if not config_path:
                 raise ValueError("A config_path is expected when resuming a run.")
+            if not Path(config_path).resolve().exists():
+                raise NotADirectoryError(
+                    f"{config_path=} is expected to be a local path. "
+                    "Resuming from the hub is not supported for now."
+                )
             policy_path = Path(config_path).parent
             self.policy.pretrained_path = policy_path
             self.checkpoint_path = policy_path.parent
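The new guard makes resuming fail fast when config_path does not point to an existing local file. A hedged sketch of that check in isolation; validate_resume_config_path is an illustrative helper name, not an identifier from this commit.

from pathlib import Path

def validate_resume_config_path(config_path: str) -> Path:
    # Same guard as above: resuming only works from a local checkpoint config,
    # not from a Hugging Face Hub reference.
    path = Path(config_path).resolve()
    if not path.exists():
        raise NotADirectoryError(
            f"{config_path=} is expected to be a local path. "
            "Resuming from the hub is not supported for now."
        )
    return path.parent  # directory holding the pretrained policy, mirroring policy_path above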

View File

@@ -151,7 +151,9 @@ def rollout(
         if return_observations:
             all_observations.append(deepcopy(observation))
-        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
+        observation = {
+            key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
+        }
         with torch.inference_mode():
             action = policy.select_action(observation)
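The rollout change stops hard-coding non_blocking=True and ties it to the target device; asynchronous copies only pay off for CPU-to-CUDA transfers and have been a source of subtle issues on other backends (my reading of the motivation, not stated in the diff). A small sketch of the same idea as a helper; to_device is an illustrative name.

import torch

def to_device(observation: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    # Async copies are only a win (and only well-defined) for CPU -> CUDA transfers,
    # so the flag is derived from the target device rather than hard-coded to True.
    return {
        key: tensor.to(device, non_blocking=device.type == "cuda")
        for key, tensor in observation.items()
    }

obs = {"observation.state": torch.randn(1, 6)}
obs = to_device(obs, torch.device("cuda" if torch.cuda.is_available() else "cpu"))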

View File

@@ -232,8 +232,10 @@ def train(cfg: TrainPipelineConfig):
         if is_log_step:
             logging.info(train_tracker)
             if wandb_logger:
-                wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
-                wandb_logger.log_dict(wandb_log_dict)
+                wandb_log_dict = train_tracker.to_dict()
+                if output_dict:
+                    wandb_log_dict.update(output_dict)
+                wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()
         if cfg.save_checkpoint and is_saving_step:
@@ -271,6 +273,7 @@ def train(cfg: TrainPipelineConfig):
                 logging.info(eval_tracker)
                 if wandb_logger:
                     wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
+                    wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                     wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
     if eval_env:
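Two fixes land in the training loop: the log dict is only merged with output_dict when it is not None (the old {**a, **b} unpacking crashes on None), and the global step is now passed through so metrics align on the same x-axis in W&B. A tiny sketch of the merge logic with plain dicts; build_log_dict and its arguments are illustrative names, not identifiers from train.py.

def build_log_dict(train_metrics: dict, extra_outputs: dict | None) -> dict:
    log_dict = dict(train_metrics)  # start from the tracker's averaged metrics
    if extra_outputs:               # `{**a, **b}` would raise TypeError when b is None
        log_dict.update(extra_outputs)
    return log_dict

print(build_log_dict({"loss": 0.3, "lr": 1e-4}, None))
print(build_log_dict({"loss": 0.3}, {"grad_norm": 2.1}))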

View File

@@ -111,9 +111,9 @@ def visualize_dataset(
     output_dir: Path | None = None,
 ) -> Path | None:
     if save:
-        assert (
-            output_dir is not None
-        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        assert output_dir is not None, (
+            "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        )
     repo_id = dataset.repo_id

View File

@@ -519,9 +519,9 @@ def test_backward_compatibility(repo_id):
     assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
     for key in new_frame:
-        assert torch.isclose(
-            new_frame[key], old_frame[key]
-        ).all(), f"{key=} for index={i} does not contain the same value"
+        assert torch.isclose(new_frame[key], old_frame[key]).all(), (
+            f"{key=} for index={i} does not contain the same value"
+        )
     # test2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()

View File

@@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
     # Check if the combined transforms directory exists and contains the right files
     combined_transforms_dir = tmp_path / "all"
     assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
-    assert any(
-        combined_transforms_dir.iterdir()
-    ), "No transformed images found in combined transforms directory."
+    assert any(combined_transforms_dir.iterdir()), (
+        "No transformed images found in combined transforms directory."
+    )
     for i in range(1, n_examples + 1):
-        assert (
-            combined_transforms_dir / f"{i}.png"
-        ).exists(), f"Combined transform image {i}.png was not found."
+        assert (combined_transforms_dir / f"{i}.png").exists(), (
+            f"Combined transform image {i}.png was not found."
+        )
 def test_save_each_transform(img_tensor_factory, tmp_path):
@@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
     # Check for specific files within each transform directory
     expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
     for file_name in expected_files:
-        assert (
-            transform_dir / file_name
-        ).exists(), f"{file_name} was not found in {transform} directory."
+        assert (transform_dir / file_name).exists(), (
+            f"{file_name} was not found in {transform} directory."
+        )

View File

@@ -132,9 +132,9 @@ def test_fifo():
     buffer.add_data(new_data)
     n_more_episodes = 2
     # Developer sanity check (in case someone changes the global `buffer_capacity`).
-    assert (
-        n_episodes + n_more_episodes
-    ) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
+    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
+        "Something went wrong with the test code."
+    )
     more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
     buffer.add_data(more_new_data)
     assert len(buffer) == buffer_capacity, "The buffer should be full."
@@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
     item = buffer[2]
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-    assert torch.equal(
-        is_pad, torch.tensor([True, False, False, True, True])
-    ), "Padding does not match expected values"
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
+        "Padding does not match expected values"
+    )
     # Arbitrarily set small dataset sizes, making sure to have uneven sizes.

View File

@@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
     observation_ = deepcopy(observation)
     with torch.inference_mode():
         action = policy.select_action(observation).cpu().numpy()
-    assert set(observation) == set(
-        observation_
-    ), "Observation batch keys are not the same after a forward pass."
+    assert set(observation) == set(observation_), (
+        "Observation batch keys are not the same after a forward pass."
+    )
-    assert all(
-        torch.equal(observation[k], observation_[k]) for k in observation
-    ), "Observation batch values are not the same after a forward pass."
+    assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
+        "Observation batch values are not the same after a forward pass."
+    )
     # Test step through policy
     env.step(action)