Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots

This commit is contained in:
Simon Alibert 2025-03-20 14:48:19 +01:00
commit 6541982dff
8 changed files with 55 additions and 33 deletions

View File

@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
### Decoding parameters
**Decoder**
We tested two video decoding backends from torchvision:
- `pyav` (default)
- `pyav`
- `video_reader` (requires to build torchvision from source)
**Requested timestamps**

View File

@ -292,6 +292,11 @@ Steps:
- Scan for devices. All 12 motors should appear.
- Select the motors one by one and move the arm. Check that the graphical indicator near the top right shows the movement.
** There is a common issue with the Dynamixel XL430-W250 motors where the motors become undiscoverable after upgrading their firmware from Mac and Windows Dynamixel Wizard2 applications. When this occurs, it is required to do a firmware recovery (Select `DYNAMIXEL Firmware Recovery` and follow the prompts). There are two known workarounds to conduct this firmware reset:
1) Install the Dynamixel Wizard on a linux machine and complete the firmware recovery
2) Use the Dynamixel U2D2 in order to perform the reset with Windows or Mac. This U2D2 can be purchased [here](https://www.robotis.us/u2d2/).
For either solution, open DYNAMIXEL Wizard 2.0 and select the appropriate port. You will likely be unable to see the motor in the GUI at this time. Select `Firmware Recovery`, carefully choose the correct model, and wait for the process to complete. Finally, re-scan to confirm the firmware recovery was successful.
**Read and Write with DynamixelMotorsBus**
To get familiar with how `DynamixelMotorsBus` communicates with the motors, you can start by reading data from them. Copy past this code in the same interactive python session:

View File

@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_video_frames,
encode_video_frames,
get_safe_default_codec,
get_video_info,
)
from lerobot.common.robots.utils import Robot
@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
"""
super().__init__()
@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "torchcodec"
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
# Unused attributes
@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import json
import logging
import subprocess
@ -27,14 +28,23 @@ import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
from torchcodec.decoders import VideoDecoder
def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logging.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "torchcodec",
backend: str | None = None,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
@ -43,13 +53,15 @@ def decode_video_frames(
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
# initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = []

View File

@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch["observation.images"] = [batch[key] for key in self.config.image_features]
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@ -413,11 +410,10 @@ class ACT(nn.Module):
"actions must be provided when using the variational objective in training mode."
)
batch_size = (
batch["observation.images"]
if "observation.images" in batch
else batch["observation.environment_state"]
).shape[0]
if "observation.images" in batch:
batch_size = batch["observation.images"][0].shape[0]
else:
batch_size = batch["observation.environment_state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
@ -490,20 +486,21 @@ class ACT(nn.Module):
all_cam_features = []
all_cam_pos_embeds = []
for cam_index in range(batch["observation.images"].shape[-4]):
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
# buffer
# For a list of images, the H and W may vary but H*W is constant.
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
# and move to (sequence, batch, dim).
all_cam_features = torch.cat(all_cam_features, axis=-1)
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)

View File

@ -580,7 +580,7 @@ class ManipulatorRobot:
# Used when record_data=True
follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.int32)
goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
@ -702,7 +702,7 @@ class ManipulatorRobot:
action_sent.append(goal_pos)
# Send goal position to each follower
goal_pos = goal_pos.numpy().astype(np.int32)
goal_pos = goal_pos.numpy().astype(np.float32)
self.follower_arms[name].write("Goal_Position", goal_pos)
return torch.cat(action_sent)

View File

@ -20,6 +20,7 @@ from lerobot.common import (
policies, # noqa: F401
)
from lerobot.common.datasets.transforms import ImageTransformsConfig
from lerobot.common.datasets.video_utils import get_safe_default_codec
@dataclass
@ -35,7 +36,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = "pyav"
video_backend: str = field(default_factory=get_safe_default_codec)
@dataclass

View File

@ -69,7 +69,7 @@ dependencies = [
"rerun-sdk>=0.21.0",
"termcolor>=2.4.0",
"torch>=2.2.1",
"torchcodec>=0.2.1",
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
"torchvision>=0.21.0",
"wandb>=0.16.3",
"zarr>=2.17.0",