Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1

Simon Alibert 2025-02-15 15:56:24 +01:00
commit aed3eb4a94
27 changed files with 134 additions and 84 deletions
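Note: three changes recur across the workflow hunks below: a top-level `permissions: {}` block, `persist-credentials: false` on `actions/checkout`, and `cache-binary: false` on `docker/setup-buildx-action`. This is a security-hardening pass consistent with the audits of zizmor, the GitHub Actions linter added to the pre-commit config further down: default token permissions are dropped, checkout no longer persists its token in `.git/config`, and the Buildx binary is not cached (a cache-poisoning surface).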


@@ -8,6 +8,8 @@ on:
   schedule:
     - cron: "0 1 * * *"
 
+permissions: {}
+
 env:
   PYTHON_VERSION: "3.10"
@@ -25,11 +27,14 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
         with:
           lfs: true
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -60,11 +65,14 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
         with:
           lfs: true
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3
@@ -89,9 +97,13 @@ jobs:
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Login to DockerHub
         uses: docker/login-action@v3


@@ -7,6 +7,8 @@ on:
   schedule:
     - cron: "0 2 * * *"
 
+permissions: {}
+
 # env:
 #   SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
 jobs:


@@ -8,6 +8,8 @@ on:
     branches:
       - main
 
+permissions: {}
+
 env:
   PYTHON_VERSION: "3.10"
@@ -17,7 +19,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Set up Python
         uses: actions/setup-python@v4
@@ -45,7 +49,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Install poetry
         run: pipx install "poetry<2.0.0"
@@ -59,7 +65,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Install poetry
         run: pipx install "poetry<2.0.0"


@@ -8,6 +8,8 @@ on:
       # Run only when DockerFile files are modified
       - "docker/**"
 
+permissions: {}
+
 env:
   PYTHON_VERSION: "3.10"
@@ -20,6 +22,8 @@ jobs:
     steps:
      - name: Check out code
        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
      - name: Get changed files
        id: changed-files
@@ -34,7 +38,7 @@ jobs:
        env:
          ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
        run: |
-          echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
+          echo "matrix=${ALL_CHANGED_FILES}" >> $GITHUB_OUTPUT
 
   build_modified_dockerfiles:
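The `GITHUB_OUTPUT` change above is a template-injection fix: `${{ ... }}` expressions are expanded by the Actions runner into the script text before the shell parses it, so a crafted file name in the changed-files output could smuggle shell syntax into the `run` block. Routing the value through the `ALL_CHANGED_FILES` environment variable lets the shell expand it as plain data. A minimal Python sketch of the same distinction (the values here are illustrative, not from the repo):

    import subprocess

    untrusted = '`touch /tmp/pwned`'  # e.g. a maliciously named file

    # Unsafe analogue of ${{ ... }}: the value is spliced into the command
    # text first, so the shell re-parses it as syntax (the backticks execute).
    # subprocess.run(f'echo "matrix={untrusted}"', shell=True)

    # Safe analogue of ${ALL_CHANGED_FILES}: the value reaches the shell via
    # the environment and is expanded as data after parsing.
    subprocess.run('echo "matrix=${ALL_CHANGED_FILES}"', shell=True, env={"ALL_CHANGED_FILES": untrusted})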
@@ -50,9 +54,13 @@ jobs:
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
+        with:
+          cache-binary: false
       - name: Check out code
         uses: actions/checkout@v4
+        with:
+          persist-credentials: false
       - name: Build Docker image
         uses: docker/build-push-action@v5


@@ -22,6 +22,8 @@ on:
       - "Makefile"
       - ".cache/**"
 
+permissions: {}
+
 jobs:
   pytest:
     name: Pytest
@@ -32,6 +34,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           lfs: true # Ensure LFS files are pulled
+          persist-credentials: false
       - name: Install apt dependencies
         # portaudio19-dev is needed to install pyaudio
@@ -72,6 +75,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           lfs: true # Ensure LFS files are pulled
+          persist-credentials: false
       - name: Install apt dependencies
         run: sudo apt-get update && sudo apt-get install -y ffmpeg
@@ -108,6 +112,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           lfs: true # Ensure LFS files are pulled
+          persist-credentials: false
       - name: Install apt dependencies
         # portaudio19-dev is needed to install pyaudio


@@ -3,8 +3,7 @@ on:
 
 name: Secret Leaks
 
-permissions:
-  contents: read
+permissions: {}
 
 jobs:
   trufflehog:
@@ -14,6 +13,8 @@ jobs:
         uses: actions/checkout@v4
         with:
           fetch-depth: 0
+          persist-credentials: false
+
       - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main
         with:


@@ -14,17 +14,17 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.2
+    rev: v0.9.6
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
   - repo: https://github.com/python-poetry/poetry
-    rev: 1.8.0
+    rev: 1.8.5
     hooks:
       - id: poetry-check
       - id: poetry-lock
@@ -32,6 +32,10 @@ repos:
           - "--check"
           - "--no-update"
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.21.2
+    rev: v8.23.3
     hooks:
       - id: gitleaks
+  - repo: https://github.com/woodruffw/zizmor-pre-commit
+    rev: v1.3.1
+    hooks:
+      - id: zizmor
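The bump from ruff v0.8.2 to v0.9.6 pulls in ruff's 2025 style guide, which appears to account for most of the mechanical Python churn in this commit: long assert messages are now parenthesized instead of the condition, and expressions inside f-strings get normalized spacing (`{1 / fps}` rather than `{1/fps}`). A runnable before/after illustration using the ACT assert from below:

    batch = {"action": [0.0]}  # minimal stand-in so the assert passes

    # ruff < 0.9 wrapped the condition:
    # assert (
    #     "action" in batch
    # ), "actions must be provided when using the variational objective in training mode."

    # ruff >= 0.9 keeps the condition inline and wraps the message:
    assert "action" in batch, (
        "actions must be provided when using the variational objective in training mode."
    )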


@@ -104,7 +104,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         )
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(dataset.repo_id_to_index , indent=2)}"
+            f"{pformat(dataset.repo_id_to_index, indent=2)}"
         )
 
     if cfg.dataset.use_imagenet_stats:


@@ -72,7 +72,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
         # are too far appart.
         direction="nearest",
-        tolerance=pd.Timedelta(f"{1/fps} seconds"),
+        tolerance=pd.Timedelta(f"{1 / fps} seconds"),
     )
 
     # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
     df = df[df["episode_index"] != -1]


@@ -409,9 +409,9 @@ class ACT(nn.Module):
             latent dimension.
         """
         if self.config.use_vae and self.training:
-            assert (
-                "action" in batch
-            ), "actions must be provided when using the variational objective in training mode."
+            assert "action" in batch, (
+                "actions must be provided when using the variational objective in training mode."
+            )
 
         batch_size = (
             batch["observation.images"]


@@ -221,7 +221,7 @@ class DiffusionConfig(PreTrainedConfig):
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )
 
     @property


@@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy):
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
 
-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
+    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
         """Do a full training forward pass to compute the loss"""
         if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy):
         losses = losses[:, :, : self.config.max_action_dim]
         loss_dict["losses_after_rm_padding"] = losses.clone()
 
-        loss = losses.mean()
         # For backward pass
-        loss_dict["loss"] = loss
+        loss = losses.mean()
         # For logging
         loss_dict["l2_loss"] = loss.item()
 
-        return loss_dict
+        return loss, loss_dict
 
     def prepare_images(self, batch):
         """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and

@@ -594,9 +594,9 @@ class TDMPCTOLD(nn.Module):
         self.apply(_apply_fn)
         for m in [self._reward, *self._Qs]:
-            assert isinstance(
-                m[-1], nn.Linear
-            ), "Sanity check. The last linear layer needs 0 initialization on weights."
+            assert isinstance(m[-1], nn.Linear), (
+                "Sanity check. The last linear layer needs 0 initialization on weights."
+            )
             nn.init.zeros_(m[-1].weight)
             nn.init.zeros_(m[-1].bias)  # this has already been done, but keep this line here for good measure


@@ -184,7 +184,7 @@ class VQBeTConfig(PreTrainedConfig):
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )
 
     @property


@@ -203,9 +203,9 @@ class GPT(nn.Module):
     def forward(self, input, targets=None):
         device = input.device
         b, t, d = input.size()
-        assert (
-            t <= self.config.gpt_block_size
-        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        assert t <= self.config.gpt_block_size, (
+            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        )
 
         # positional encodings that are added to the input embeddings
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
@@ -273,10 +273,10 @@ class GPT(nn.Module):
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
             str(inter_params)
         )
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), "parameters {} were not separated into either decay/no_decay set!".format(
-            str(param_dict.keys() - union_params),
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters {} were not separated into either decay/no_decay set!".format(
+                str(param_dict.keys() - union_params),
+            )
         )
 
         decay = [param_dict[pn] for pn in sorted(decay)]
@@ -419,9 +419,9 @@ class ResidualVQ(nn.Module):
         # and the network should be able to reconstruct
 
         if quantize_dim < self.num_quantizers:
-            assert (
-                self.quantize_dropout > 0.0
-            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            assert self.quantize_dropout > 0.0, (
+                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            )
             indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
 
         # get ready for gathering
@@ -472,9 +472,9 @@ class ResidualVQ(nn.Module):
         all_indices = []
 
         if return_loss:
-            assert not torch.any(
-                indices == -1
-            ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            assert not torch.any(indices == -1), (
+                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            )
             ce_losses = []
 
         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
@@ -887,9 +887,9 @@ class VectorQuantize(nn.Module):
             # only calculate orthogonal loss for the activated codes for this batch
 
             if self.orthogonal_reg_active_codes_only:
-                assert not (
-                    is_multiheaded and self.separate_codebook_per_head
-                ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+                assert not (is_multiheaded and self.separate_codebook_per_head), (
+                    "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+                )
                 unique_code_ids = torch.unique(embed_ind)
                 codebook = codebook[:, unique_code_ids]
@@ -999,9 +999,9 @@ def gumbel_sample(
         ind = sampling_logits.argmax(dim=dim)
     one_hot = F.one_hot(ind, size).type(dtype)
 
-    assert not (
-        reinmax and not straight_through
-    ), "reinmax can only be turned on if using straight through gumbel softmax"
+    assert not (reinmax and not straight_through), (
+        "reinmax can only be turned on if using straight through gumbel softmax"
+    )
 
     if not straight_through or temperature <= 0.0 or not training:
         return ind, one_hot
@@ -1209,9 +1209,9 @@ class EuclideanCodebook(nn.Module):
         self.gumbel_sample = gumbel_sample
         self.sample_codebook_temp = sample_codebook_temp
 
-        assert not (
-            use_ddp and num_codebooks > 1 and kmeans_init
-        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
+            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        )
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop


@@ -33,7 +33,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
     def log_dt(shortname, dt_val_s):
         nonlocal log_items, fps
-        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
+        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
         if fps is not None:
             actual_fps = 1 / dt_val_s
             if actual_fps < fps - 1:


@@ -58,7 +58,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
         # Check that they have exactly the same set of keys.
         if target.keys() != source.keys():
             raise ValueError(
-                f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
+                f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
             )
 
         # Recursively update each key.


@@ -102,7 +102,7 @@ class WandBLogger:
         self._wandb.log_artifact(artifact)
 
     def log_dict(self, d: dict, step: int, mode: str = "train"):
-        if mode in {"train", "eval"}:
+        if mode not in {"train", "eval"}:
             raise ValueError(mode)
 
         for k, v in d.items():
@@ -114,7 +114,7 @@ class WandBLogger:
             self._wandb.log({f"{mode}/{k}": v}, step=step)
 
     def log_video(self, video_path: str, step: int, mode: str = "train"):
-        if mode in {"train", "eval"}:
+        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
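Unlike the surrounding formatting churn, this is a behavior fix: the old guard raised `ValueError` for exactly the two valid modes and let everything else through. A minimal sketch of the corrected guard (`_check_mode` is a hypothetical helper, not from the repo):

    def _check_mode(mode: str) -> None:
        # "train" and "eval" pass through; anything else is rejected.
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

    _check_mode("train")  # ok
    # _check_mode("test")  # raises ValueError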


@@ -85,6 +85,11 @@ class TrainPipelineConfig(HubMixin):
             config_path = parser.parse_arg("config_path")
             if not config_path:
                 raise ValueError("A config_path is expected when resuming a run.")
+            if not Path(config_path).resolve().exists():
+                raise NotADirectoryError(
+                    f"{config_path=} is expected to be a local path. "
+                    "Resuming from the hub is not supported for now."
+                )
             policy_path = Path(config_path).parent
             self.policy.pretrained_path = policy_path
             self.checkpoint_path = policy_path.parent


@@ -151,7 +151,9 @@ def rollout(
         if return_observations:
             all_observations.append(deepcopy(observation))
 
-        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
+        observation = {
+            key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
+        }
 
         with torch.inference_mode():
             action = policy.select_action(observation)
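The rewritten comprehension above gates `non_blocking` on the target device: asynchronous host-to-device copies are only meaningful for CUDA targets (ideally from pinned memory), so passing `non_blocking=True` unconditionally is at best a no-op and potentially surprising on CPU or MPS. A standalone sketch of the pattern (`move_to_device` is an illustrative name, not from the repo):

    import torch

    def move_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
        # Only request an async copy when the target is CUDA.
        return {k: v.to(device, non_blocking=device.type == "cuda") for k, v in batch.items()}

    batch = {"observation.state": torch.zeros(1, 7)}
    batch = move_to_device(batch, torch.device("cpu"))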


@@ -232,8 +232,10 @@ def train(cfg: TrainPipelineConfig):
             if is_log_step:
                 logging.info(train_tracker)
                 if wandb_logger:
-                    wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
-                    wandb_logger.log_dict(wandb_log_dict)
+                    wandb_log_dict = train_tracker.to_dict()
+                    if output_dict:
+                        wandb_log_dict.update(output_dict)
+                    wandb_logger.log_dict(wandb_log_dict, step)
                 train_tracker.reset_averages()
 
             if cfg.save_checkpoint and is_saving_step:
@@ -271,6 +273,7 @@ def train(cfg: TrainPipelineConfig):
                 logging.info(eval_tracker)
                 if wandb_logger:
                     wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
+                    wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                     wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
 
         if eval_env:
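Three fixes land in the hunks above: `output_dict` may be `None`, in which case the old `{**train_tracker.to_dict(), **output_dict}` unpacking raises `TypeError`; the train-mode `log_dict` call now passes the previously missing `step`; and an eval-mode `log_dict` call is added next to `log_video`. A self-contained illustration of the merge guard:

    train_metrics = {"loss": 0.25, "lr": 1e-4}
    output_dict = None  # a policy may return no extra outputs

    # Old: {**train_metrics, **output_dict} -> TypeError when output_dict is None.
    wandb_log_dict = dict(train_metrics)
    if output_dict:
        wandb_log_dict.update(output_dict)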


@@ -111,9 +111,9 @@ def visualize_dataset(
     output_dir: Path | None = None,
 ) -> Path | None:
     if save:
-        assert (
-            output_dir is not None
-        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        assert output_dir is not None, (
+            "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        )
 
     repo_id = dataset.repo_id


@@ -49,17 +49,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
     # save 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
 
     # save 2 frames at the middle of first episode
     i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
     save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
-    save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
+    save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
 
     # save 2 last frames of first episode
     i = dataset.episode_data_index["to"][0].item()
-    save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
-    save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
+    save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
+    save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
 
     # TODO(rcadene): Enable testing on second and last episode
     # We currently cant because our test dataset only contains the first episode


@@ -519,9 +519,9 @@ def test_backward_compatibility(repo_id):
         assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
 
         for key in new_frame:
-            assert torch.isclose(
-                new_frame[key], old_frame[key]
-            ).all(), f"{key=} for index={i} does not contain the same value"
+            assert torch.isclose(new_frame[key], old_frame[key]).all(), (
+                f"{key=} for index={i} does not contain the same value"
+            )
 
     # test2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()


@@ -343,13 +343,13 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
     # Check if the combined transforms directory exists and contains the right files
     combined_transforms_dir = tmp_path / "all"
     assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
-    assert any(
-        combined_transforms_dir.iterdir()
-    ), "No transformed images found in combined transforms directory."
+    assert any(combined_transforms_dir.iterdir()), (
+        "No transformed images found in combined transforms directory."
+    )
     for i in range(1, n_examples + 1):
-        assert (
-            combined_transforms_dir / f"{i}.png"
-        ).exists(), f"Combined transform image {i}.png was not found."
+        assert (combined_transforms_dir / f"{i}.png").exists(), (
+            f"Combined transform image {i}.png was not found."
+        )
 
 
 def test_save_each_transform(img_tensor_factory, tmp_path):
@@ -369,6 +369,6 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
         # Check for specific files within each transform directory
         expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
         for file_name in expected_files:
-            assert (
-                transform_dir / file_name
-            ).exists(), f"{file_name} was not found in {transform} directory."
+            assert (transform_dir / file_name).exists(), (
+                f"{file_name} was not found in {transform} directory."
+            )


@@ -132,9 +132,9 @@ def test_fifo():
     buffer.add_data(new_data)
     n_more_episodes = 2
     # Developer sanity check (in case someone changes the global `buffer_capacity`).
-    assert (
-        n_episodes + n_more_episodes
-    ) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
+    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
+        "Something went wrong with the test code."
+    )
     more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
     buffer.add_data(more_new_data)
     assert len(buffer) == buffer_capacity, "The buffer should be full."
@@ -203,9 +203,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
     item = buffer[2]
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-    assert torch.equal(
-        is_pad, torch.tensor([True, False, False, True, True])
-    ), "Padding does not match expected values"
+    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
+        "Padding does not match expected values"
+    )
 
 
 # Arbitrarily set small dataset sizes, making sure to have uneven sizes.


@@ -193,12 +193,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
         observation_ = deepcopy(observation)
         with torch.inference_mode():
             action = policy.select_action(observation).cpu().numpy()
-        assert set(observation) == set(
-            observation_
-        ), "Observation batch keys are not the same after a forward pass."
-        assert all(
-            torch.equal(observation[k], observation_[k]) for k in observation
-        ), "Observation batch values are not the same after a forward pass."
+        assert set(observation) == set(observation_), (
+            "Observation batch keys are not the same after a forward pass."
+        )
+        assert all(torch.equal(observation[k], observation_[k]) for k in observation), (
+            "Observation batch values are not the same after a forward pass."
+        )
 
         # Test step through policy
         env.step(action)