Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla

This commit is contained in:
lesjie-wen 2025-03-20 13:33:24 +08:00
commit 575fc92e69
11 changed files with 64 additions and 38 deletions

View File

@ -41,7 +41,7 @@ jobs:
- name: Get changed files - name: Get changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v44 uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
with: with:
files: docker/** files: docker/**
json: "true" json: "true"

View File

@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
### Decoding parameters ### Decoding parameters
**Decoder** **Decoder**
We tested two video decoding backends from torchvision: We tested two video decoding backends from torchvision:
- `pyav` (default) - `pyav`
- `video_reader` (requires building torchvision from source) - `video_reader` (requires building torchvision from source)
**Requested timestamps** **Requested timestamps**

View File

@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames, decode_video_frames,
encode_video_frames, encode_video_frames,
get_safe_default_codec,
get_video_info, get_video_info,
) )
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to video files are already present on local disk, they won't be downloaded again. Defaults to
True. True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available in the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
""" """
super().__init__() super().__init__()
@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes self.episodes = episodes
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "torchcodec" self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None self.delta_indices = None
# Unused attributes # Unused attributes
@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None obj.delta_timestamps = None
obj.delta_indices = None obj.delta_indices = None
obj.episode_data_index = None obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "torchcodec" obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj return obj

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import json import json
import logging import logging
import subprocess import subprocess
@ -27,14 +28,23 @@ import torch
import torchvision import torchvision
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from torchcodec.decoders import VideoDecoder
def get_safe_default_codec():
    """Return the name of the default video decoding backend.

    Returns:
        str: "torchcodec" when the `torchcodec` package can be located by the
        import system; otherwise "pyav", after logging a warning about the
        fallback.
    """
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before calling `find_spec`.
    import importlib.util

    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"

    logging.warning(
        "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
    )
    return "pyav"
def decode_video_frames( def decode_video_frames(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
tolerance_s: float, tolerance_s: float,
backend: str = "torchcodec", backend: str | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes video frames using the specified backend. Decodes video frames using the specified backend.
@ -43,13 +53,15 @@ def decode_video_frames(
video_path (Path): Path to the video file. video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames. timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval. tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec". backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
Returns: Returns:
torch.Tensor: Decoded frames. torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav. Currently supports torchcodec on cpu and pyav.
""" """
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec": if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]: elif backend in ["pyav", "video_reader"]:
@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes. can be adjusted during encoding to take into account decoding time and video size in bytes.
""" """
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
# initialize video decoder # initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = [] loaded_frames = []

View File

@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
if self.config.image_features: if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack( batch["observation.images"] = [batch[key] for key in self.config.image_features]
[batch[key] for key in self.config.image_features], dim=-4
)
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over. # we are ensembling over.
@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
if self.config.image_features: if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack( batch["observation.images"] = [batch[key] for key in self.config.image_features]
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@ -413,11 +410,10 @@ class ACT(nn.Module):
"actions must be provided when using the variational objective in training mode." "actions must be provided when using the variational objective in training mode."
) )
batch_size = ( if "observation.images" in batch:
batch["observation.images"] batch_size = batch["observation.images"][0].shape[0]
if "observation.images" in batch else:
else batch["observation.environment_state"] batch_size = batch["observation.environment_state"].shape[0]
).shape[0]
# Prepare the latent for input to the transformer encoder. # Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch: if self.config.use_vae and "action" in batch:
@ -490,20 +486,21 @@ class ACT(nn.Module):
all_cam_features = [] all_cam_features = []
all_cam_pos_embeds = [] all_cam_pos_embeds = []
for cam_index in range(batch["observation.images"].shape[-4]): # For a list of images, the H and W may vary but H*W is constant.
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"] for img in batch["observation.images"]:
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use cam_features = self.backbone(img)["feature_map"]
# buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
all_cam_features.append(cam_features) all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed) all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
# and move to (sequence, batch, dim). encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
all_cam_features = torch.cat(all_cam_features, axis=-1) encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
# Stack all tokens along the sequence dimension. # Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0) encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)

View File

@ -474,7 +474,7 @@ class ManipulatorRobot:
# Used when record_data=True # Used when record_data=True
follower_goal_pos[name] = goal_pos follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.int32) goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos) self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
@ -596,7 +596,7 @@ class ManipulatorRobot:
action_sent.append(goal_pos) action_sent.append(goal_pos)
# Send goal position to each follower # Send goal position to each follower
goal_pos = goal_pos.numpy().astype(np.int32) goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos) self.follower_arms[name].write("Goal_Position", goal_pos)
return torch.cat(action_sent) return torch.cat(action_sent)

View File

@ -69,7 +69,13 @@ class WandBLogger:
os.environ["WANDB_SILENT"] = "True" os.environ["WANDB_SILENT"] = "True"
import wandb import wandb
wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None wandb_run_id = (
cfg.wandb.run_id
if cfg.wandb.run_id
else get_wandb_run_id_from_filesystem(self.log_dir)
if cfg.resume
else None
)
wandb.init( wandb.init(
id=wandb_run_id, id=wandb_run_id,
project=self.cfg.project, project=self.cfg.project,

View File

@ -20,6 +20,7 @@ from lerobot.common import (
policies, # noqa: F401 policies, # noqa: F401
) )
from lerobot.common.datasets.transforms import ImageTransformsConfig from lerobot.common.datasets.transforms import ImageTransformsConfig
from lerobot.common.datasets.video_utils import get_safe_default_codec
@dataclass @dataclass
@ -35,7 +36,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None revision: str | None = None
use_imagenet_stats: bool = True use_imagenet_stats: bool = True
video_backend: str = "pyav" video_backend: str = field(default_factory=get_safe_default_codec)
@dataclass @dataclass
@ -46,6 +47,7 @@ class WandBConfig:
project: str = "lerobot" project: str = "lerobot"
entity: str | None = None entity: str | None = None
notes: str | None = None notes: str | None = None
run_id: str | None = None
@dataclass @dataclass

View File

@ -79,7 +79,9 @@ class TrainPipelineConfig(HubMixin):
# The entire train config is already loaded, we just need to get the checkpoint dir # The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path") config_path = parser.parse_arg("config_path")
if not config_path: if not config_path:
raise ValueError("A config_path is expected when resuming a run.") raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
if not Path(config_path).resolve().exists(): if not Path(config_path).resolve().exists():
raise NotADirectoryError( raise NotADirectoryError(
f"{config_path=} is expected to be a local path. " f"{config_path=} is expected to be a local path. "

View File

@ -69,7 +69,7 @@ dependencies = [
"rerun-sdk>=0.21.0", "rerun-sdk>=0.21.0",
"termcolor>=2.4.0", "termcolor>=2.4.0",
"torch>=2.2.1", "torch>=2.2.1",
"torchcodec>=0.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"wandb>=0.16.3", "wandb>=0.16.3",
"zarr>=2.17.0", "zarr>=2.17.0",