Jade Choghari 2025-03-13 21:53:53 +03:00
parent 1329954dba
commit d1bec3e8ae
6 changed files with 10 additions and 836 deletions

View File

@@ -39,7 +39,6 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision,
    encode_video_frames,
    decode_video_frames_torchcodec,
)
from lerobot.common.utils.benchmark import TimeBenchmark
@@ -68,6 +67,10 @@ def parse_int_or_none(value) -> int | None:
def check_datasets_formats(repo_ids: list) -> None:
    for repo_id in repo_ids:
        dataset = LeRobotDataset(repo_id)
        if dataset.video:
            raise ValueError(
                f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
            )
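For context, a hypothetical call to the check above; the repo ids are illustrative placeholders, not taken from the commit:

check_datasets_formats(["lerobot/pusht_image", "lerobot/aloha_sim_insertion_human_image"])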
def get_directory_size(directory: Path) -> int:
@@ -152,10 +155,6 @@ def decode_video_frames(
) -> torch.Tensor:
    if backend in ["pyav", "video_reader"]:
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
    elif backend in ["torchcodec-cpu", "torchcodec-gpu"]:
        # Only pass device once depending on the backend
        device = "cpu" if backend == "torchcodec-cpu" else "cuda"
        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, device=device)
    else:
        raise NotImplementedError(backend)
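Not part of the commit: a minimal sketch of how the remaining backends could still be exercised by this benchmark, assuming the decode_video_frames(video_path, timestamps, tolerance_s, backend) helper shown above and a placeholder clip path:

from pathlib import Path

video_path = Path("outputs/video_benchmark/episode_000000.mp4")  # placeholder path
timestamps = [0.0, 0.5, 1.0]
tolerance_s = 1 / 30  # accept the nearest frame within one 30 fps frame period

for backend in ["pyav", "video_reader"]:  # "torchcodec-cpu"/"torchcodec-gpu" are gone after this commit
    frames = decode_video_frames(video_path, timestamps, tolerance_s, backend)
    print(backend, frames.shape, frames.dtype)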
@@ -189,7 +188,7 @@ def benchmark_decoding(
    original_frames = load_original_frames(imgs_dir, timestamps, fps)
    result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
    frames_np, original_frames_np = frames.cpu().numpy(), original_frames.cpu().numpy()
    frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
    for i in range(num_frames):
        result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
        result["psnr_values"].append(

View File

@@ -660,48 +660,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            # if what is returned is all the info that i used query_timestamps, episode
            # percentage of chance, 30% cpu, gpu
            # video_frames = self._query_videos(query_timestamps, ep_idx)
            # item = {**video_frames, **item}
            # jade - instead of decoding video, return video path & timestamps
            # hack only add metadata
            item["video_paths"] = {
                vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
                for vid_key in query_timestamps.keys()
            }
            item["query_timestamps"] = query_timestamps

        if self.image_transforms is not None:
            breakpoint()
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                item[cam] = self.image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __getitem2__(self, idx) -> dict:
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        # data logic
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        # video logic
        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
@@ -718,6 +677,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __repr__(self):
        feature_keys = list(self.features)
        return (
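As a rough illustration of the removed __getitem__ hack above: instead of decoded frames, an item carried only video metadata for the collate step to resolve later. The keys come from the diff; the camera key and values below are made up:

from pathlib import Path

item = {
    "timestamp": 1.23,                                   # scalar from the hf_dataset row
    "task": "push the T block to the target",            # illustrative task string
    "video_paths": {
        "observation.images.top": Path("videos/observation.images.top_episode_000000.mp4"),
    },
    "query_timestamps": {"observation.images.top": [1.23, 1.26]},
}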

View File

@@ -127,67 +127,6 @@ def decode_video_frames_torchvision(
    return closest_frames


def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec."""
    video_path = str(video_path)

    # initialize video decoder
    from torchcodec.decoders import VideoDecoder

    decoder = VideoDecoder(video_path, device=device)
    loaded_frames = []
    loaded_ts = []

    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [int(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255
    assert len(timestamps) == len(closest_frames)

    return closest_frames
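Not part of the diff: a small self-contained check of the tolerance matching performed above, using torch.cdist on made-up query and loaded timestamps.

import torch

query_ts = torch.tensor([0.50, 1.02])
loaded_ts = torch.tensor([0.500, 1.000])
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)  # pairwise |query - loaded|
min_, argmin_ = dist.min(1)
# min_    -> tensor([0.00, 0.02]): both within a tolerance_s of 1/30 s
# argmin_ -> tensor([0, 1]): index of the closest loaded frame for each query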
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,

View File

@@ -23,7 +23,7 @@ import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from pathlib import Path
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
@@ -51,58 +51,7 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision, decode_video_frames_torchcodec
)
# let's define a custom fn
def custom_collate_fn(batch):
    # always in the cuda, getitem is on cpu,
    # then implement mixed
    """
    Custom collate function that decodes videos on GPU/CPU.
    Converts the batch to a dictionary with keys representing each field.
    Returns a tensor for video frames instead of a list.
    """
    # know when it is called
    final_batch = {}
    is_main_process = torch.utils.data.get_worker_info() is None

    # the batch is given as a list, we need to return a dict
    for item in batch:
        # process video decoding for each item
        if "video_paths" in item and "query_timestamps" in item:
            for vid_key, video_path in item["video_paths"].items():
                # decode video frames based on timestamps
                timestamps = item["query_timestamps"][vid_key]

                # ✅ Use CUDA only in the main process
                device = "cuda" if is_main_process else "cpu"

                frames = decode_video_frames_torchcodec(
                    video_path=Path(video_path),
                    timestamps=timestamps,
                    tolerance_s=0.02,
                    # backend="pyav",
                    log_loaded_timestamps=False,
                    device=device,  # ✅ Keeps CUDA safe
                )

                # stack frames for this video key and add directly to the item
                item[vid_key] = frames

        # add item data (both video and non-video) to final_batch
        for key, value in item.items():
            if key not in final_batch:
                final_batch[key] = []
            final_batch[key].append(value)

    # now, stack tensors for each key in final_batch
    # this is needed to ensure that video frames (and any other tensor fields) are combined
    # into a single tensor per field, rather than a list of tensors!
    for key in final_batch:
        if isinstance(final_batch[key][0], torch.Tensor):
            final_batch[key] = torch.stack(final_batch[key])  # stack tensors if needed

    return final_batch
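A minimal sketch of the worker check the collate function relies on: torch.utils.data.get_worker_info() returns None only in the main process, so GPU decoding is skipped inside DataLoader worker processes. The extra cuda availability guard is an addition for illustration, not in the diff.

import torch

def pick_decode_device() -> str:
    # main process: allowed to touch CUDA; worker processes: fall back to CPU decoding
    in_main_process = torch.utils.data.get_worker_info() is None
    return "cuda" if in_main_process and torch.cuda.is_available() else "cpu"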
def update_policy(
    train_metrics: MetricsTracker,
@@ -233,11 +182,12 @@ def train(cfg: TrainPipelineConfig):
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        collate_fn=custom_collate_fn,
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
@@ -255,6 +205,7 @@
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)
@@ -280,7 +231,6 @@
        if is_log_step:
            logging.info(train_tracker)
            breakpoint()
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:

Binary file not shown.

View File

@@ -1,674 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import gym_pusht # noqa: F401\n",
"import gymnasium as gym\n",
"import imageio\n",
"import numpy\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Select your device\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):\n",
"pretrained_policy_path = \"IliaLarchenko/dot_pusht_keypoints\""
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from lerobot.common.policies.dot.modeling_dot import DOTPolicy\n",
"policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"env = gym.make(\n",
" \"gym_pusht/PushT-v0\",\n",
" obs_type=\"environment_state_agent_pos\",\n",
" max_episode_steps=300,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,)), 'observation.environment_state': PolicyFeature(type=<FeatureType.ENV: 'ENV'>, shape=(16,))}\n",
"Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'environment_state': Box(0.0, 512.0, (16,), float64))\n"
]
}
],
"source": [
"print(policy.config.input_features)\n",
"print(env.observation_space)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(2,))}\n",
"Box(0.0, 512.0, (2,), float32)\n"
]
}
],
"source": [
"print(policy.config.output_features)\n",
"print(env.action_space)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"policy.reset()\n",
"numpy_observation, info = env.reset(seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Prepare to collect every rewards and all the frames of the episode,\n",
"# from initial state to final state.\n",
"rewards = []\n",
"frames = []\n",
"\n",
"# Render frame of the initial state\n",
"frames.append(env.render())"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step=0 reward=np.float64(0.0) terminated=False\n",
"step=1 reward=np.float64(0.0) terminated=False\n",
"step=2 reward=np.float64(0.0) terminated=False\n",
"step=3 reward=np.float64(0.0) terminated=False\n",
"step=4 reward=np.float64(0.0) terminated=False\n",
"step=5 reward=np.float64(0.0) terminated=False\n",
"step=6 reward=np.float64(0.0) terminated=False\n",
"step=7 reward=np.float64(0.0) terminated=False\n",
"step=8 reward=np.float64(0.0) terminated=False\n",
"step=9 reward=np.float64(0.0) terminated=False\n",
"step=10 reward=np.float64(0.0) terminated=False\n",
"step=11 reward=np.float64(0.0) terminated=False\n",
"step=12 reward=np.float64(0.0) terminated=False\n",
"step=13 reward=np.float64(0.0) terminated=False\n",
"step=14 reward=np.float64(0.0) terminated=False\n",
"step=15 reward=np.float64(0.0) terminated=False\n",
"step=16 reward=np.float64(0.0) terminated=False\n",
"step=17 reward=np.float64(0.0) terminated=False\n",
"step=18 reward=np.float64(0.0) terminated=False\n",
"step=19 reward=np.float64(0.0) terminated=False\n",
"step=20 reward=np.float64(0.0) terminated=False\n",
"step=21 reward=np.float64(0.0) terminated=False\n",
"step=22 reward=np.float64(0.0009941544780861455) terminated=False\n",
"step=23 reward=np.float64(0.033647507519038757) terminated=False\n",
"step=24 reward=np.float64(0.07026086006261555) terminated=False\n",
"step=25 reward=np.float64(0.10069667553409196) terminated=False\n",
"step=26 reward=np.float64(0.11389926069925992) terminated=False\n",
"step=27 reward=np.float64(0.12027077768723497) terminated=False\n",
"step=28 reward=np.float64(0.12486582623684722) terminated=False\n",
"step=29 reward=np.float64(0.12815916861048604) terminated=False\n",
"step=30 reward=np.float64(0.1303391815805222) terminated=False\n",
"step=31 reward=np.float64(0.1315231117258188) terminated=False\n",
"step=32 reward=np.float64(0.13221640549835664) terminated=False\n",
"step=33 reward=np.float64(0.13254763259209015) terminated=False\n",
"step=34 reward=np.float64(0.13263558368425837) terminated=False\n",
"step=35 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=36 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=37 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=38 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=39 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=40 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=41 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=42 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=43 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=44 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=45 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=46 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=47 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=48 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=49 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=50 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=51 reward=np.float64(0.13263932937572084) terminated=False\n",
"step=52 reward=np.float64(0.14872285364307145) terminated=False\n",
"step=53 reward=np.float64(0.19847005044261715) terminated=False\n",
"step=54 reward=np.float64(0.24338272205852812) terminated=False\n",
"step=55 reward=np.float64(0.2667243347061481) terminated=False\n",
"step=56 reward=np.float64(0.2691675276421592) terminated=False\n",
"step=57 reward=np.float64(0.3018254762158707) terminated=False\n",
"step=58 reward=np.float64(0.3613501686331564) terminated=False\n",
"step=59 reward=np.float64(0.4613512243665896) terminated=False\n",
"step=60 reward=np.float64(0.5617656688929643) terminated=False\n",
"step=61 reward=np.float64(0.5747609180351871) terminated=False\n",
"step=62 reward=np.float64(0.507048651118485) terminated=False\n",
"step=63 reward=np.float64(0.44332042287270484) terminated=False\n",
"step=64 reward=np.float64(0.3993804222553378) terminated=False\n",
"step=65 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=66 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=67 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=68 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=69 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=70 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=71 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=72 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=73 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=74 reward=np.float64(0.36941592278664487) terminated=False\n",
"step=75 reward=np.float64(0.4322328474940646) terminated=False\n",
"step=76 reward=np.float64(0.4818152566968738) terminated=False\n",
"step=77 reward=np.float64(0.5252535051167734) terminated=False\n",
"step=78 reward=np.float64(0.5586446249197407) terminated=False\n",
"step=79 reward=np.float64(0.5885022076599307) terminated=False\n",
"step=80 reward=np.float64(0.5977994643852952) terminated=False\n",
"step=81 reward=np.float64(0.597859201570885) terminated=False\n",
"step=82 reward=np.float64(0.597859201570885) terminated=False\n",
"step=83 reward=np.float64(0.597859201570885) terminated=False\n",
"step=84 reward=np.float64(0.597859201570885) terminated=False\n",
"step=85 reward=np.float64(0.597859201570885) terminated=False\n",
"step=86 reward=np.float64(0.597859201570885) terminated=False\n",
"step=87 reward=np.float64(0.597859201570885) terminated=False\n",
"step=88 reward=np.float64(0.597859201570885) terminated=False\n",
"step=89 reward=np.float64(0.6876341127908622) terminated=False\n",
"step=90 reward=np.float64(0.8166289152424572) terminated=False\n",
"step=91 reward=np.float64(0.9421614978354362) terminated=False\n",
"step=92 reward=np.float64(0.9441608976568224) terminated=False\n",
"step=93 reward=np.float64(0.9104163604296966) terminated=False\n",
"step=94 reward=np.float64(0.909182661371769) terminated=False\n",
"step=95 reward=np.float64(0.909182661371769) terminated=False\n",
"step=96 reward=np.float64(0.909182661371769) terminated=False\n",
"step=97 reward=np.float64(0.909182661371769) terminated=False\n",
"step=98 reward=np.float64(0.909182661371769) terminated=False\n",
"step=99 reward=np.float64(0.909182661371769) terminated=False\n",
"step=100 reward=np.float64(0.909182661371769) terminated=False\n",
"step=101 reward=np.float64(0.909182661371769) terminated=False\n",
"step=102 reward=np.float64(0.909182661371769) terminated=False\n",
"step=103 reward=np.float64(0.9340357871805705) terminated=False\n",
"step=104 reward=np.float64(0.8851102121651142) terminated=False\n",
"step=105 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=106 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=107 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=108 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=109 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=110 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=111 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=112 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=113 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=114 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=115 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=116 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=117 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=118 reward=np.float64(0.8809768749693764) terminated=False\n",
"step=119 reward=np.float64(0.9518089158714292) terminated=False\n",
"step=120 reward=np.float64(0.9405458729516311) terminated=False\n",
"step=121 reward=np.float64(0.935511214687435) terminated=False\n",
"step=122 reward=np.float64(0.935511214687435) terminated=False\n",
"step=123 reward=np.float64(0.935511214687435) terminated=False\n",
"step=124 reward=np.float64(0.935511214687435) terminated=False\n",
"step=125 reward=np.float64(0.935511214687435) terminated=False\n",
"step=126 reward=np.float64(0.935511214687435) terminated=False\n",
"step=127 reward=np.float64(0.935511214687435) terminated=False\n",
"step=128 reward=np.float64(0.935511214687435) terminated=False\n",
"step=129 reward=np.float64(0.935511214687435) terminated=False\n",
"step=130 reward=np.float64(0.935511214687435) terminated=False\n",
"step=131 reward=np.float64(0.935511214687435) terminated=False\n",
"step=132 reward=np.float64(0.935511214687435) terminated=False\n",
"step=133 reward=np.float64(0.935511214687435) terminated=False\n",
"step=134 reward=np.float64(0.935511214687435) terminated=False\n",
"step=135 reward=np.float64(0.9534990217822209) terminated=False\n",
"step=136 reward=np.float64(0.9596585109597399) terminated=False\n",
"step=137 reward=np.float64(0.882875733420291) terminated=False\n",
"step=138 reward=np.float64(0.8277880190838034) terminated=False\n",
"step=139 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=140 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=141 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=142 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=143 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=144 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=145 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=146 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=147 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=148 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=149 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=150 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=151 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=152 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=153 reward=np.float64(0.8211529871155266) terminated=False\n",
"step=154 reward=np.float64(0.856408395120982) terminated=False\n",
"step=155 reward=np.float64(0.9304040416833055) terminated=False\n",
"step=156 reward=np.float64(0.9770812279113622) terminated=False\n",
"step=157 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=158 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=159 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=160 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=161 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=162 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=163 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=164 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=165 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=166 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=167 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=168 reward=np.float64(0.9944597232685968) terminated=False\n",
"step=169 reward=np.float64(0.9820541217675358) terminated=False\n",
"step=170 reward=np.float64(0.8765528119646949) terminated=False\n",
"step=171 reward=np.float64(0.8231919366320396) terminated=False\n",
"step=172 reward=np.float64(0.7926155231821123) terminated=False\n",
"step=173 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=174 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=175 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=176 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=177 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=178 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=179 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=180 reward=np.float64(0.7902960563492054) terminated=False\n",
"step=181 reward=np.float64(0.8158199658870418) terminated=False\n",
"step=182 reward=np.float64(0.8191627090126786) terminated=False\n",
"step=183 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=184 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=185 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=186 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=187 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=188 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=189 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=190 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=191 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=192 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=193 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=194 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=195 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=196 reward=np.float64(0.8172224839948001) terminated=False\n",
"step=197 reward=np.float64(0.878735078350138) terminated=False\n",
"step=198 reward=np.float64(0.8564816396314117) terminated=False\n",
"step=199 reward=np.float64(0.7970005244772627) terminated=False\n",
"step=200 reward=np.float64(0.7688729860960439) terminated=False\n",
"step=201 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=202 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=203 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=204 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=205 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=206 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=207 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=208 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=209 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=210 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=211 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=212 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=213 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=214 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=215 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=216 reward=np.float64(0.7671148163486476) terminated=False\n",
"step=217 reward=np.float64(0.8045352949993082) terminated=False\n",
"step=218 reward=np.float64(0.8328184612705187) terminated=False\n",
"step=219 reward=np.float64(0.8558996801195216) terminated=False\n",
"step=220 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=221 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=222 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=223 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=224 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=225 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=226 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=227 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=228 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=229 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=230 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=231 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=232 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=233 reward=np.float64(0.8576887923798919) terminated=False\n",
"step=234 reward=np.float64(0.935454295639811) terminated=False\n",
"step=235 reward=np.float64(0.9874094853870982) terminated=False\n",
"step=236 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=237 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=238 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=239 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=240 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=241 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=242 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=243 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=244 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=245 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=246 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=247 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=248 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=249 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=250 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=251 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=252 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=253 reward=np.float64(0.9719847917595543) terminated=False\n",
"step=254 reward=np.float64(0.9727790631955697) terminated=False\n",
"step=255 reward=np.float64(0.946125141646605) terminated=False\n",
"step=256 reward=np.float64(0.9368755165399575) terminated=False\n",
"step=257 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=258 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=259 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=260 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=261 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=262 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=263 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=264 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=265 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=266 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=267 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=268 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=269 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=270 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=271 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=272 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=273 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=274 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=275 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=276 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=277 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=278 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=279 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=280 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=281 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=282 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=283 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=284 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=285 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=286 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=287 reward=np.float64(0.9360986686274937) terminated=False\n",
"step=288 reward=np.float64(0.9316550466986755) terminated=False\n",
"step=289 reward=np.float64(0.9218676877631473) terminated=False\n",
"step=290 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=291 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=292 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=293 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=294 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=295 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=296 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=297 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=298 reward=np.float64(0.9213441513220694) terminated=False\n",
"step=299 reward=np.float64(0.9213441513220694) terminated=False\n"
]
}
],
"source": [
"step = 0\n",
"done = False\n",
"\n",
"while not done:\n",
" # Prepare observation for the policy\n",
" state = torch.from_numpy(numpy_observation[\"agent_pos\"]) # Agent position\n",
" env_state = torch.from_numpy(numpy_observation[\"environment_state\"]) # Environment state\n",
"\n",
" # Convert to float32\n",
" state = state.to(torch.float32)\n",
" env_state = env_state.to(torch.float32)\n",
"\n",
" # Send data tensors from CPU to GPU\n",
" state = state.to(device, non_blocking=True)\n",
" env_state = env_state.to(device, non_blocking=True)\n",
"\n",
" # Add extra (empty) batch dimension, required to forward the policy\n",
" state = state.unsqueeze(0)\n",
" env_state = env_state.unsqueeze(0)\n",
"\n",
" # Create the policy input dictionary\n",
" observation = {\n",
" \"observation.state\": state,\n",
" \"observation.environment_state\": env_state, # Add environment_state here\n",
" }\n",
"\n",
" # Predict the next action with respect to the current observation\n",
" with torch.inference_mode():\n",
" action = policy.select_action(observation)\n",
"\n",
" # Prepare the action for the environment\n",
" numpy_action = action.squeeze(0).to(\"cpu\").numpy()\n",
"\n",
" # Step through the environment and receive a new observation\n",
" numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)\n",
" print(f\"{step=} {reward=} {terminated=}\")\n",
"\n",
" # Keep track of all the rewards and frames\n",
" rewards.append(reward)\n",
" frames.append(env.render())\n",
"\n",
" # The rollout is considered done when the success state is reached (i.e. terminated is True),\n",
" # or the maximum number of iterations is reached (i.e. truncated is True)\n",
" done = terminated or truncated or done\n",
" step += 1\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Failure!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (680, 680) to (688, 688) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video of the evaluation is available in '/home/lerobot/output/rollout.mp4'.\n"
]
}
],
"source": [
"if terminated:\n",
" print(\"Success!\")\n",
"else:\n",
" print(\"Failure!\")\n",
"\n",
"# Get the speed of environment (i.e. its number of frames per second).\n",
"fps = env.metadata[\"render_fps\"]\n",
"\n",
"# Encode all frames into a mp4 video.\n",
"video_path = \"/home/lerobot/output/rollout.mp4\"\n",
"imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)\n",
"\n",
"print(f\"Video of the evaluation is available in '{video_path}'.\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"#now on aloha\n",
"import imageio\n",
"import gymnasium as gym\n",
"import numpy as np\n",
"import gym_aloha\n",
"env = gym.make(\n",
" \"gym_aloha/AlohaInsertion-v0\",\n",
" obs_type=\"pixels\",\n",
" max_episode_steps=300,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/envs/lerobot/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
"100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]\n"
]
}
],
"source": [
"from lerobot.common.policies.dot.modeling_dot import DOTPolicy\n",
"pretrained_policy_path = \"IliaLarchenko/dot_bimanual_insert\"\n",
"policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'observation.images.top': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(14,))}\n",
"Dict('top': Box(0, 255, (480, 640, 3), uint8))\n"
]
}
],
"source": [
"# We can verify that the shapes of the features expected by the policy match the ones from the observations\n",
"# produced by the environment\n",
"print(policy.config.input_features)\n",
"print(env.observation_space)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(14,))}\n",
"Box(-1.0, 1.0, (14,), float32)\n"
]
}
],
"source": [
"# Similarly, we can check that the actions produced by the policy will match the actions expected by the\n",
"# environment\n",
"print(policy.config.output_features)\n",
"print(env.action_space)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"ename": "FatalError",
"evalue": "gladLoadGL error",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFatalError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Reset the policy and environments to prepare for rollout\u001b[39;00m\n\u001b[1;32m 2\u001b[0m policy\u001b[38;5;241m.\u001b[39mreset()\n\u001b[0;32m----> 3\u001b[0m numpy_observation, info \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/time_limit.py:75\u001b[0m, in \u001b[0;36mTimeLimit.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.\u001b[39;00m\n\u001b[1;32m 67\u001b[0m \n\u001b[1;32m 68\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m The reset environment\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_elapsed_steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/order_enforcing.py:61\u001b[0m, in \u001b[0;36mOrderEnforcing.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with `kwargs`.\"\"\"\u001b[39;00m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_has_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/env_checker.py:57\u001b[0m, in \u001b[0;36mPassiveEnvChecker.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43menv_reset_passive_checker\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39mreset(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:186\u001b[0m, in \u001b[0;36menv_reset_passive_checker\u001b[0;34m(env, **kwargs)\u001b[0m\n\u001b[1;32m 181\u001b[0m logger\u001b[38;5;241m.\u001b[39mdeprecation(\n\u001b[1;32m 182\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCurrent gymnasium version requires that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 183\u001b[0m )\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# Checks the result of env.reset with kwargs\u001b[39;00m\n\u001b[0;32m--> 186\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(result, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 189\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 190\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(result)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m )\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/env.py:166\u001b[0m, in \u001b[0;36mAlohaEnv.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtask)\n\u001b[0;32m--> 166\u001b[0m raw_obs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_raw_obs(raw_obs\u001b[38;5;241m.\u001b[39mobservation)\n\u001b[1;32m 170\u001b[0m info \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis_success\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mFalse\u001b[39;00m}\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/rl/control.py:89\u001b[0m, in \u001b[0;36mEnvironment.reset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mreset_context():\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_task\u001b[38;5;241m.\u001b[39minitialize_episode(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics)\n\u001b[0;32m---> 89\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_task\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flat_observation:\n\u001b[1;32m 91\u001b[0m observation \u001b[38;5;241m=\u001b[39m flatten_observation(observation)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/tasks/sim.py:92\u001b[0m, in \u001b[0;36mBimanualViperXTask.get_observation\u001b[0;34m(self, physics)\u001b[0m\n\u001b[1;32m 90\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menv_state\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_env_state(physics)\n\u001b[1;32m 91\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m---> 92\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtop\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mphysics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrender\u001b[49m\u001b[43m(\u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m480\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m640\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtop\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 94\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvis\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfront_close\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:223\u001b[0m, in \u001b[0;36mPhysics.render\u001b[0;34m(self, height, width, camera_id, overlays, depth, segmentation, scene_option, render_flag_overrides, scene_callback)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrender\u001b[39m(\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 180\u001b[0m height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m240\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01mNone\u001b[39;00m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 190\u001b[0m ):\n\u001b[1;32m 191\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns a camera view as a NumPy array of pixel values.\u001b[39;00m\n\u001b[1;32m 192\u001b[0m \n\u001b[1;32m 193\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;124;03m The rendered RGB, depth or segmentation image.\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 223\u001b[0m camera \u001b[38;5;241m=\u001b[39m \u001b[43mCamera\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mphysics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwidth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcamera_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mscene_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscene_callback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m image \u001b[38;5;241m=\u001b[39m camera\u001b[38;5;241m.\u001b[39mrender(\n\u001b[1;32m 230\u001b[0m overlays\u001b[38;5;241m=\u001b[39moverlays, depth\u001b[38;5;241m=\u001b[39mdepth, segmentation\u001b[38;5;241m=\u001b[39msegmentation,\n\u001b[1;32m 231\u001b[0m scene_option\u001b[38;5;241m=\u001b[39mscene_option, render_flag_overrides\u001b[38;5;241m=\u001b[39mrender_flag_overrides)\n\u001b[1;32m 232\u001b[0m camera\u001b[38;5;241m.\u001b[39m_scene\u001b[38;5;241m.\u001b[39mfree() \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:711\u001b[0m, in \u001b[0;36mCamera.__init__\u001b[0;34m(self, physics, height, width, camera_id, max_geom, scene_callback)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rgb_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width, \u001b[38;5;241m3\u001b[39m), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39muint8)\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_depth_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m--> 711\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontexts\u001b[49m\u001b[38;5;241m.\u001b[39mmujoco \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 712\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mgl\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[1;32m 713\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN,\n\u001b[1;32m 714\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mmujoco\u001b[38;5;241m.\u001b[39mptr)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:533\u001b[0m, in \u001b[0;36mPhysics.contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts_lock:\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts:\n\u001b[0;32m--> 533\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_rendering_contexts\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 534\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:519\u001b[0m, in \u001b[0;36mPhysics._make_rendering_contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 516\u001b[0m render_context \u001b[38;5;241m=\u001b[39m _render\u001b[38;5;241m.\u001b[39mRenderer(\n\u001b[1;32m 517\u001b[0m max_width\u001b[38;5;241m=\u001b[39mmax_width, max_height\u001b[38;5;241m=\u001b[39mmax_height)\n\u001b[1;32m 518\u001b[0m \u001b[38;5;66;03m# Create the MuJoCo context.\u001b[39;00m\n\u001b[0;32m--> 519\u001b[0m mujoco_context \u001b[38;5;241m=\u001b[39m \u001b[43mwrapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrender_context\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 520\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts \u001b[38;5;241m=\u001b[39m Contexts(gl\u001b[38;5;241m=\u001b[39mrender_context, mujoco\u001b[38;5;241m=\u001b[39mmujoco_context)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/wrapper/core.py:603\u001b[0m, in \u001b[0;36mMjrContext.__init__\u001b[0;34m(self, model, gl_context, font_scale)\u001b[0m\n\u001b[1;32m 601\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gl_context \u001b[38;5;241m=\u001b[39m gl_context\n\u001b[1;32m 602\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m gl_context\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[0;32m--> 603\u001b[0m ptr \u001b[38;5;241m=\u001b[39m \u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmujoco\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mptr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfont_scale\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 604\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN, ptr)\n\u001b[1;32m 605\u001b[0m gl_context\u001b[38;5;241m.\u001b[39mkeep_alive(ptr)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/_render/executor/render_executor.py:138\u001b[0m, in \u001b[0;36mPassthroughRenderExecutor.call\u001b[0;34m(self, func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcall\u001b[39m(\u001b[38;5;28mself\u001b[39m, func, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_locked()\n\u001b[0;32m--> 138\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mFatalError\u001b[0m: gladLoadGL error"
]
}
],
"source": [
"# Reset the policy and environments to prepare for rollout\n",
"policy.reset()\n",
"numpy_observation, info = env.reset(seed=42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
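The gladLoadGL FatalError in the deleted notebook's last executed cell typically means no OpenGL context is available, for example on a headless server. A common workaround, unrelated to this commit, is to select a headless MuJoCo rendering backend before dm_control is imported:

import os

os.environ.setdefault("MUJOCO_GL", "egl")  # or "osmesa" for pure software rendering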