Merge remote-tracking branch 'origin/2025_02_20_add_dexvla' into 2025_02_20_add_dexvla

commit 575fc92e69
@@ -41,7 +41,7 @@ jobs:
       - name: Get changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
         with:
           files: docker/**
           json: "true"
@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
 ### Decoding parameters
 **Decoder**
 We tested two video decoding backends from torchvision:
-- `pyav` (default)
+- `pyav`
 - `video_reader` (requires to build torchvision from source)

 **Requested timestamps**
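For reference, torchvision selects its decoding backend globally, which is how the two backends above are switched. A minimal sketch, assuming torchvision is installed (and built from source if `video_reader` is wanted):

import torchvision

# "pyav" is the default backend; "video_reader" requires a source build.
torchvision.set_video_backend("pyav")
print(torchvision.get_video_backend())  # -> "pyav"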
@@ -69,6 +69,7 @@ from lerobot.common.datasets.video_utils import (
     VideoFrame,
     decode_video_frames,
     encode_video_frames,
+    get_safe_default_codec,
     get_video_info,
 )
 from lerobot.common.robot_devices.robots.utils import Robot
@@ -462,7 +463,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                 video files are already present on local disk, they won't be downloaded again. Defaults to
                 True.
-            video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
+            video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'.
                 You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader', which is another decoder of Torchvision.
         """
         super().__init__()
@@ -473,7 +474,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
-        self.video_backend = video_backend if video_backend else "torchcodec"
+        self.video_backend = video_backend if video_backend else get_safe_default_codec()
         self.delta_indices = None

         # Unused attributes
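A usage sketch of the new behavior (the repo_id below is a placeholder):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# With video_backend left unset, the dataset now resolves its decoder via
# get_safe_default_codec(): "torchcodec" if importable, else "pyav".
ds = LeRobotDataset("user/some_dataset")  # placeholder repo_id
ds_pyav = LeRobotDataset("user/some_dataset", video_backend="pyav")  # explicit override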
@@ -1027,7 +1028,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_timestamps = None
         obj.delta_indices = None
         obj.episode_data_index = None
-        obj.video_backend = video_backend if video_backend is not None else "torchcodec"
+        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
         return obj

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import json
 import logging
 import subprocess
@@ -27,14 +28,23 @@ import torch
 import torchvision
 from datasets.features.features import register_feature
 from PIL import Image
-from torchcodec.decoders import VideoDecoder
+
+
+def get_safe_default_codec():
+    if importlib.util.find_spec("torchcodec"):
+        return "torchcodec"
+    else:
+        logging.warning(
+            "'torchcodec' is not available on your platform, falling back to 'pyav' as a default decoder"
+        )
+        return "pyav"


 def decode_video_frames(
     video_path: Path | str,
     timestamps: list[float],
     tolerance_s: float,
-    backend: str = "torchcodec",
+    backend: str | None = None,
 ) -> torch.Tensor:
     """
     Decodes video frames using the specified backend.
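The new helper relies on `importlib.util.find_spec`, which probes whether a package is importable without actually importing it, so the check stays cheap even when torchcodec is installed. The same pattern in isolation:

import importlib.util
import logging

def pick_codec() -> str:
    # find_spec returns a ModuleSpec when the package can be imported and
    # None otherwise, so the heavy torchcodec import is never triggered
    # just to probe for its presence.
    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"
    logging.warning("torchcodec unavailable, falling back to pyav")
    return "pyav"

print(pick_codec())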
@@ -43,13 +53,15 @@ def decode_video_frames(
         video_path (Path): Path to the video file.
         timestamps (list[float]): List of timestamps to extract frames.
         tolerance_s (float): Allowed deviation in seconds for frame retrieval.
-        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
+        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available on the platform; otherwise, defaults to "pyav".

     Returns:
         torch.Tensor: Decoded frames.

     Currently supports torchcodec on cpu and pyav.
     """
+    if backend is None:
+        backend = get_safe_default_codec()
     if backend == "torchcodec":
         return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
     elif backend in ["pyav", "video_reader"]:
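A usage sketch of the updated signature (the path and timestamps are placeholders):

import torch
from lerobot.common.datasets.video_utils import decode_video_frames

# backend=None now means "resolve at call time", so the same call works on
# platforms with and without torchcodec installed.
frames = decode_video_frames(
    video_path="episode_000000.mp4",  # placeholder path
    timestamps=[0.0, 0.033, 0.066],
    tolerance_s=1e-4,
)
assert isinstance(frames, torch.Tensor)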
@@ -173,6 +185,12 @@ def decode_video_frames_torchcodec(
     and all subsequent frames until reaching the requested frame. The number of key frames in a video
     can be adjusted during encoding to take into account decoding time and video size in bytes.
     """
+
+    if importlib.util.find_spec("torchcodec"):
+        from torchcodec.decoders import VideoDecoder
+    else:
+        raise ImportError("torchcodec is required but not available.")
+
     # initialize video decoder
     decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
     loaded_frames = []
@@ -119,9 +119,7 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]

         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
@@ -149,9 +147,8 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
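Both call sites now pass a list of per-camera tensors instead of a stacked (B, N, C, H, W) tensor. Stacking forces every camera to share one resolution; a list does not. A minimal sketch of the distinction (shapes are illustrative):

import torch

cam_a = torch.rand(2, 3, 480, 640)  # (B, C, H, W)
cam_b = torch.rand(2, 3, 240, 320)  # a different resolution

# torch.stack([cam_a, cam_b], dim=-4) would raise a RuntimeError here,
# because stack requires identical shapes across cameras.

# A list defers any shape requirement to the encoder, which only consumes
# each camera's flattened tokens after the backbone.
batch_images = [cam_a, cam_b]
batch_size = batch_images[0].shape[0]  # matches the new batch_size logic below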
@@ -413,11 +410,10 @@ class ACT(nn.Module):
                 "actions must be provided when using the variational objective in training mode."
             )

-        batch_size = (
-            batch["observation.images"]
-            if "observation.images" in batch
-            else batch["observation.environment_state"]
-        ).shape[0]
+        if "observation.images" in batch:
+            batch_size = batch["observation.images"][0].shape[0]
+        else:
+            batch_size = batch["observation.environment_state"].shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
|
@ -490,20 +486,21 @@ class ACT(nn.Module):
|
||||||
all_cam_features = []
|
all_cam_features = []
|
||||||
all_cam_pos_embeds = []
|
all_cam_pos_embeds = []
|
||||||
|
|
||||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
# For a list of images, the H and W may vary but H*W is constant.
|
||||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
for img in batch["observation.images"]:
|
||||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
cam_features = self.backbone(img)["feature_map"]
|
||||||
# buffer
|
|
||||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||||
|
|
||||||
|
# Rearrange features to (sequence, batch, dim).
|
||||||
|
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
|
||||||
|
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
|
||||||
|
|
||||||
all_cam_features.append(cam_features)
|
all_cam_features.append(cam_features)
|
||||||
all_cam_pos_embeds.append(cam_pos_embed)
|
all_cam_pos_embeds.append(cam_pos_embed)
|
||||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
|
||||||
# and move to (sequence, batch, dim).
|
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
|
||||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
|
||||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
|
||||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
|
||||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
|
||||||
|
|
||||||
# Stack all tokens along the sequence dimension.
|
# Stack all tokens along the sequence dimension.
|
||||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||||
|
|
|
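Downstream, each feature map is now flattened to tokens before concatenation, so cameras only need to agree on channel count, not on h and w. An illustrative sketch with einops (shapes are made up):

import torch
import einops

f1 = torch.rand(2, 64, 4, 5)  # (B, C, h, w) feature map from camera 1
f2 = torch.rand(2, 64, 3, 7)  # a different h and w is now fine

# Flatten each map to (h*w, B, C) tokens, then concatenate along the
# sequence axis instead of concatenating raw maps along width.
tokens = [einops.rearrange(f, "b c h w -> (h w) b c") for f in (f1, f2)]
seq = torch.cat(tokens, dim=0)
print(seq.shape)  # torch.Size([41, 2, 64]) -> 4*5 + 3*7 tokens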
@@ -474,7 +474,7 @@ class ManipulatorRobot:
             # Used when record_data=True
             follower_goal_pos[name] = goal_pos

-            goal_pos = goal_pos.numpy().astype(np.int32)
+            goal_pos = goal_pos.numpy().astype(np.float32)
             self.follower_arms[name].write("Goal_Position", goal_pos)
             self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t

@@ -596,7 +596,7 @@ class ManipulatorRobot:
             action_sent.append(goal_pos)

             # Send goal position to each follower
-            goal_pos = goal_pos.numpy().astype(np.int32)
+            goal_pos = goal_pos.numpy().astype(np.float32)
             self.follower_arms[name].write("Goal_Position", goal_pos)

         return torch.cat(action_sent)
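The dtype change matters because casting through int32 silently truncates any fractional goal position before it reaches the motor bus. A small numpy demonstration:

import numpy as np

goal = np.array([1023.7, 511.2])
print(goal.astype(np.int32))    # [1023  511] -- fractional part dropped
print(goal.astype(np.float32))  # [1023.7  511.2] -- kept for the write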
@@ -69,7 +69,13 @@ class WandBLogger:
         os.environ["WANDB_SILENT"] = "True"
         import wandb

-        wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
+        wandb_run_id = (
+            cfg.wandb.run_id
+            if cfg.wandb.run_id
+            else get_wandb_run_id_from_filesystem(self.log_dir)
+            if cfg.resume
+            else None
+        )
         wandb.init(
             id=wandb_run_id,
             project=self.cfg.project,
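The new expression chains two conditionals; Python parses `a if c1 else b if c2 else d` as `a if c1 else (b if c2 else d)`, so an explicit `run_id` wins even when not resuming. A standalone sketch (function and values are illustrative):

def resolve_run_id(explicit_id, resume, fs_id="from-filesystem"):
    # Mirrors the precedence in the diff above.
    return explicit_id if explicit_id else fs_id if resume else None

print(resolve_run_id("abc123", resume=False))  # 'abc123' -- explicit id wins
print(resolve_run_id(None, resume=True))       # 'from-filesystem'
print(resolve_run_id(None, resume=False))      # None -- fresh run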
@@ -20,6 +20,7 @@ from lerobot.common import (
     policies,  # noqa: F401
 )
 from lerobot.common.datasets.transforms import ImageTransformsConfig
+from lerobot.common.datasets.video_utils import get_safe_default_codec


 @dataclass
@@ -35,7 +36,7 @@ class DatasetConfig:
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
     revision: str | None = None
     use_imagenet_stats: bool = True
-    video_backend: str = "pyav"
+    video_backend: str = field(default_factory=get_safe_default_codec)


 @dataclass
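Using `field(default_factory=...)` matters here: writing `video_backend: str = get_safe_default_codec()` would probe for torchcodec once at import time, whereas the factory runs for each config instance. A minimal sketch (names are stand-ins):

from dataclasses import dataclass, field

def detect_backend() -> str:  # stand-in for get_safe_default_codec
    print("probing...")
    return "pyav"

@dataclass
class Cfg:
    # default_factory runs once per Cfg() call, not at class definition,
    # so the probe reflects the environment at instantiation time.
    backend: str = field(default_factory=detect_backend)

cfg = Cfg()         # prints "probing..."
print(cfg.backend)  # "pyav"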
@@ -46,6 +47,7 @@ class WandBConfig:
     project: str = "lerobot"
     entity: str | None = None
     notes: str | None = None
+    run_id: str | None = None


 @dataclass
@@ -79,7 +79,9 @@ class TrainPipelineConfig(HubMixin):
             # The entire train config is already loaded, we just need to get the checkpoint dir
             config_path = parser.parse_arg("config_path")
             if not config_path:
-                raise ValueError("A config_path is expected when resuming a run.")
+                raise ValueError(
+                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
+                )
             if not Path(config_path).resolve().exists():
                 raise NotADirectoryError(
                     f"{config_path=} is expected to be a local path. "
@@ -69,7 +69,7 @@ dependencies = [
     "rerun-sdk>=0.21.0",
     "termcolor>=2.4.0",
     "torch>=2.2.1",
-    "torchcodec>=0.2.1",
+    "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
     "torchvision>=0.21.0",
     "wandb>=0.16.3",
     "zarr>=2.17.0",
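The new requirement string is a PEP 508 environment marker that skips torchcodec on Windows and on non-x86 Linux machines. Markers can be evaluated with the `packaging` library; a small sketch:

from packaging.markers import Marker

marker = Marker(
    "sys_platform != 'win32' and (sys_platform != 'linux' or "
    "(platform_machine != 'aarch64' and platform_machine != 'arm64' "
    "and platform_machine != 'armv7l'))"
)
# evaluate() uses the current interpreter's environment by default; passing
# a dict lets you test other platforms.
print(marker.evaluate({"sys_platform": "linux", "platform_machine": "x86_64"}))   # True
print(marker.evaluate({"sys_platform": "linux", "platform_machine": "aarch64"}))  # False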