cleanup

This commit is contained in:
parent 1329954dba
commit d1bec3e8ae
@@ -39,7 +39,6 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision,
    encode_video_frames,
-   decode_video_frames_torchcodec,
)
from lerobot.common.utils.benchmark import TimeBenchmark
@@ -68,6 +67,10 @@ def parse_int_or_none(value) -> int | None:
def check_datasets_formats(repo_ids: list) -> None:
    for repo_id in repo_ids:
        dataset = LeRobotDataset(repo_id)
+       if dataset.video:
+           raise ValueError(
+               f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
+           )


def get_directory_size(directory: Path) -> int:
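A quick illustration of how this guard might be invoked before the benchmark runs; the repo id below is a placeholder, not necessarily one this script uses:

    # Hypothetical usage of the guard above: fail fast if a video-backed
    # dataset is passed to this image-only benchmark.
    check_datasets_formats(["user/some_image_dataset"])  # placeholder repo id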
@@ -152,10 +155,6 @@ def decode_video_frames(
) -> torch.Tensor:
    if backend in ["pyav", "video_reader"]:
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
-   elif backend in ["torchcodec-cpu", "torchcodec-gpu"]:
-       # Only pass device once depending on the backend
-       device = "cpu" if backend == "torchcodec-cpu" else "cuda"
-       return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, device=device)
    else:
        raise NotImplementedError(backend)
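For context, a sketch of how the dispatcher above might be called. The path and timing values are placeholders, and the keyword names assume the parameter names used inside the function body shown in the hunk:

    # Hypothetical call into the decode_video_frames dispatcher above. After
    # this cleanup, only the torchvision backends ("pyav", "video_reader")
    # remain valid; torchcodec backends now raise NotImplementedError.
    frames = decode_video_frames(
        video_path="videos/episode_000000.mp4",  # placeholder path
        timestamps=[0.0, 1 / 30, 2 / 30],        # query times in seconds
        tolerance_s=1e-4,                        # allowed timestamp error
        backend="pyav",
    )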
@@ -189,7 +188,7 @@ def benchmark_decoding(
    original_frames = load_original_frames(imgs_dir, timestamps, fps)
    result["load_time_images_ms"] = time_benchmark.result_ms / num_frames

-   frames_np, original_frames_np = frames.cpu().numpy(), original_frames.cpu().numpy()
+   frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
    for i in range(num_frames):
        result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
        result["psnr_values"].append(
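A self-contained sketch of the frame-quality comparison this loop performs, assuming the scikit-image metrics the benchmark appears to use (its imports are not shown in this diff):

    # Sketch: comparing decoded frames against ground-truth images with
    # MSE and PSNR, as in the benchmark loop above. Data is synthetic.
    import numpy as np
    from skimage.metrics import mean_squared_error, peak_signal_noise_ratio

    rng = np.random.default_rng(0)
    original = rng.random((64, 64, 3)).astype(np.float32)  # stand-in ground truth
    decoded = np.clip(original + rng.normal(0, 0.01, original.shape), 0, 1).astype(np.float32)

    mse = mean_squared_error(original, decoded)
    psnr = peak_signal_noise_ratio(original, decoded, data_range=1.0)
    print(f"mse={mse:.6f} psnr={psnr:.2f} dB")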
@@ -660,48 +660,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val
        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            # if what is returned is all the info that i used query_timestamps, episode
            # percentage of chance, 30% cpu, gpu
            # video_frames = self._query_videos(query_timestamps, ep_idx)
            # item = {**video_frames, **item}

            # jade - instead of decoding video, return video path & timestamps
            # hack only add metadata
            item["video_paths"] = {
                vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
                for vid_key in query_timestamps.keys()
            }
            item["query_timestamps"] = query_timestamps

        if self.image_transforms is not None:
            breakpoint()
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                item[cam] = self.image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __getitem2__(self, idx) -> dict:
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        # data logic
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        # video logic
        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
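The hack removed above defers video decoding out of __getitem__ by returning only video paths and query timestamps, leaving the decode to the collate function. A minimal self-contained sketch of that pattern; class and field names here are illustrative placeholders, not LeRobot's API:

    # Sketch of the "return metadata, decode later" pattern the removed
    # code implemented. All names below are illustrative.
    from pathlib import Path
    import torch

    class LazyVideoDataset(torch.utils.data.Dataset):
        def __init__(self, video_paths: list[Path], fps: float):
            self.video_paths = video_paths
            self.fps = fps

        def __len__(self) -> int:
            return len(self.video_paths)

        def __getitem__(self, idx: int) -> dict:
            # No decoding here: just enough metadata for a collate_fn
            # to decode the frame later, on whichever device it chooses.
            return {
                "video_path": str(self.video_paths[idx]),
                "query_timestamps": [idx / self.fps],
            }

    ds = LazyVideoDataset([Path("a.mp4")], fps=30)
    print(ds[0])  # works without touching the file: no decode happens here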
@@ -718,6 +677,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __repr__(self):
        feature_keys = list(self.features)
        return (
@@ -127,67 +127,6 @@ def decode_video_frames_torchvision(
    return closest_frames


def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec."""
    video_path = str(video_path)

    # Initialize the video decoder.
    from torchcodec.decoders import VideoDecoder

    decoder = VideoDecoder(video_path, device=device)
    loaded_frames = []
    loaded_ts = []

    # Get metadata for frame information.
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # Convert timestamps to frame indices.
    frame_indices = [int(ts * average_fps) for ts in timestamps]

    # Retrieve frames based on indices.
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # Compute the distance between each query timestamp and the loaded timestamps.
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        " It means that the closest frame that can be loaded from the video is too far away in time."
        " This might be due to synchronization issues with timestamps during data collection."
        " To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # Get the closest frames to the query timestamps.
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # Convert to float32 in the [0, 1] range (channel first).
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames


def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
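The core of the removed function is nearest-neighbor matching of query timestamps to decoded frame timestamps, with a tolerance check. A standalone sketch of just that matching step, with made-up timestamps:

    # Standalone sketch of the timestamp-matching logic from the removed
    # function above. Values are synthetic.
    import torch

    query_ts = torch.tensor([0.10, 0.50, 0.90])      # requested times (s)
    loaded_ts = torch.tensor([0.099, 0.501, 0.899])  # decoded frame times (s)
    tolerance_s = 1e-2

    # Pairwise |query - loaded| distances; pick the closest loaded frame per query.
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    assert (min_ < tolerance_s).all(), "a query timestamp has no frame within tolerance"
    print(argmin_.tolist())  # indices of the closest frames, e.g. [0, 1, 2]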
@@ -23,7 +23,7 @@ import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from pathlib import Path

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
@@ -51,58 +51,7 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision, decode_video_frames_torchcodec
)

# let's define a custom fn
def custom_collate_fn(batch):
    # always in the cuda, getitem is on cpu,
    # then implement mixed
    """
    Custom collate function that decodes videos on GPU/CPU.
    Converts the batch to a dictionary with keys representing each field.
    Returns a tensor for video frames instead of a list.
    """
    # know when it is called
    final_batch = {}
    is_main_process = torch.utils.data.get_worker_info() is None

    # the batch is given as a list, we need to return a dict
    for item in batch:
        # process video decoding for each item
        if "video_paths" in item and "query_timestamps" in item:
            for vid_key, video_path in item["video_paths"].items():
                # decode video frames based on timestamps
                timestamps = item["query_timestamps"][vid_key]

                # ✅ Use CUDA only in the main process
                device = "cuda" if is_main_process else "cpu"
                frames = decode_video_frames_torchcodec(
                    video_path=Path(video_path),
                    timestamps=timestamps,
                    tolerance_s=0.02,
                    # backend="pyav",
                    log_loaded_timestamps=False,
                    device=device,  # ✅ Keeps CUDA safe
                )
                # stack frames for this video key and add directly to the item
                item[vid_key] = frames

        # add item data (both video and non-video) to final_batch
        for key, value in item.items():
            if key not in final_batch:
                final_batch[key] = []
            final_batch[key].append(value)

        # now, stack tensors for each key in final_batch
        # this is needed to ensure that video frames (and any other tensor fields) are combined
        # into a single tensor per field, rather than a list of tensors!
    for key in final_batch:
        if isinstance(final_batch[key][0], torch.Tensor):
            final_batch[key] = torch.stack(final_batch[key])  # stack tensors if needed

    return final_batch


def update_policy(
    train_metrics: MetricsTracker,
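How a decoding collate function like the one removed above plugs into a DataLoader, and why the worker check matters: CUDA cannot safely be initialized inside forked worker processes, so decoding stays on CPU there. A hedged, self-contained sketch with a dummy dataset (none of these names are LeRobot's):

    # Sketch: wiring a device-aware collate_fn into a DataLoader.
    # num_workers=0 keeps collation in the main process, where CUDA is safe.
    import torch

    class DummyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"x": torch.tensor([float(idx)])}

    def collate(batch):
        # get_worker_info() is None in the main process only.
        on_main = torch.utils.data.get_worker_info() is None
        device = "cuda" if on_main and torch.cuda.is_available() else "cpu"
        return {"x": torch.stack([item["x"] for item in batch]).to(device)}

    loader = torch.utils.data.DataLoader(DummyDataset(), batch_size=4, collate_fn=collate)
    print(next(iter(loader))["x"].shape)  # torch.Size([4, 1])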
@@ -233,11 +182,12 @@ def train(cfg: TrainPipelineConfig):
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        collate_fn=custom_collate_fn,
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
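The cycle(dataloader) call turns a finite dataloader into an endless iterator, so the training loop can run by step count rather than by epoch. An illustrative stand-in for lerobot's own cycle helper (which is imported above, not defined here):

    # Illustrative sketch of what a dataloader `cycle` helper does.
    def cycle(iterable):
        while True:
            for item in iterable:
                yield item

    dl_iter = cycle([1, 2, 3])
    print([next(dl_iter) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]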
@@ -255,6 +205,7 @@ def train(cfg: TrainPipelineConfig):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)
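Note that non_blocking=True only overlaps the host-to-device copy with compute when the source tensor lives in pinned memory, which is why the DataLoader above sets pin_memory when training on an accelerator. A minimal sketch of the same pattern, with synthetic data:

    # Sketch: pinned-memory + non_blocking transfers, mirroring the loop above.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {"observation.state": torch.randn(32, 14)}

    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            src = batch[key].pin_memory() if device.type == "cuda" else batch[key]
            # non_blocking only helps when src is pinned and the target is CUDA.
            batch[key] = src.to(device, non_blocking=True)
    print(batch["observation.state"].device)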
@@ -280,7 +231,6 @@ def train(cfg: TrainPipelineConfig):

        if is_log_step:
            logging.info(train_tracker)
-           breakpoint()
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
Binary file not shown.

tester.ipynb (674 lines deleted)

@@ -1,674 +0,0 @@
The deleted tester.ipynb was a scratch notebook. Its cells are rendered below in order, with the notebook JSON scaffolding dropped and long outputs condensed.

Cell 1 (imports):

    from pathlib import Path

    import gym_pusht  # noqa: F401
    import gymnasium as gym
    import imageio
    import numpy
    import torch

Cell 2 (device selection):

    # Select your device
    device = "cuda" if torch.cuda.is_available() else "cpu"

Cell 3 (pretrained policy repo id):

    # Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
    pretrained_policy_path = "IliaLarchenko/dot_pusht_keypoints"

Cell 4 (load the policy):

    from lerobot.common.policies.dot.modeling_dot import DOTPolicy
    policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)

Cell 5 (create the PushT environment):

    env = gym.make(
        "gym_pusht/PushT-v0",
        obs_type="environment_state_agent_pos",
        max_episode_steps=300,
    )

Cell 6 (check input features against the observation space):

    print(policy.config.input_features)
    print(env.observation_space)

    Output:
    {'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,)), 'observation.environment_state': PolicyFeature(type=<FeatureType.ENV: 'ENV'>, shape=(16,))}
    Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'environment_state': Box(0.0, 512.0, (16,), float64))

Cell 7 (check output features against the action space):

    print(policy.config.output_features)
    print(env.action_space)

    Output:
    {'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(2,))}
    Box(0.0, 512.0, (2,), float32)

Cell 8 (reset policy and environment):

    policy.reset()
    numpy_observation, info = env.reset(seed=42)

Cell 9 (prepare rollout buffers):

    # Prepare to collect every rewards and all the frames of the episode,
    # from initial state to final state.
    rewards = []
    frames = []

    # Render frame of the initial state
    frames.append(env.render())

Cell 10 (rollout loop):

    step = 0
    done = False

    while not done:
        # Prepare observation for the policy
        state = torch.from_numpy(numpy_observation["agent_pos"])  # Agent position
        env_state = torch.from_numpy(numpy_observation["environment_state"])  # Environment state

        # Convert to float32
        state = state.to(torch.float32)
        env_state = env_state.to(torch.float32)

        # Send data tensors from CPU to GPU
        state = state.to(device, non_blocking=True)
        env_state = env_state.to(device, non_blocking=True)

        # Add extra (empty) batch dimension, required to forward the policy
        state = state.unsqueeze(0)
        env_state = env_state.unsqueeze(0)

        # Create the policy input dictionary
        observation = {
            "observation.state": state,
            "observation.environment_state": env_state,  # Add environment_state here
        }

        # Predict the next action with respect to the current observation
        with torch.inference_mode():
            action = policy.select_action(observation)

        # Prepare the action for the environment
        numpy_action = action.squeeze(0).to("cpu").numpy()

        # Step through the environment and receive a new observation
        numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
        print(f"{step=} {reward=} {terminated=}")

        # Keep track of all the rewards and frames
        rewards.append(reward)
        frames.append(env.render())

        # The rollout is considered done when the success state is reached (i.e. terminated is True),
        # or the maximum number of iterations is reached (i.e. truncated is True)
        done = terminated or truncated or done
        step += 1

    Output (condensed): a per-step log from step=0 to step=299. The reward stays at 0.0
    until step=22, then climbs unevenly (peaking around 0.994 near steps 157-168) and
    ends at step=299 with reward=0.9213441513220694 and terminated=False.

Cell 11 (save the rollout video):

    if terminated:
        print("Success!")
    else:
        print("Failure!")

    # Get the speed of environment (i.e. its number of frames per second).
    fps = env.metadata["render_fps"]

    # Encode all frames into a mp4 video.
    video_path = "/home/lerobot/output/rollout.mp4"
    imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

    print(f"Video of the evaluation is available in '{video_path}'.")

    Output:
    Failure!
    IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (680, 680) to (688, 688) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).
    Video of the evaluation is available in '/home/lerobot/output/rollout.mp4'.

Cell 12 (switch to ALOHA):

    # now on aloha
    import imageio
    import gymnasium as gym
    import numpy as np
    import gym_aloha
    env = gym.make(
        "gym_aloha/AlohaInsertion-v0",
        obs_type="pixels",
        max_episode_steps=300,
    )

Cell 13 (load the bimanual-insertion policy):

    from lerobot.common.policies.dot.modeling_dot import DOTPolicy
    pretrained_policy_path = "IliaLarchenko/dot_bimanual_insert"
    policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)

    Output (stderr, condensed): a tqdm/ipywidgets IProgress warning, then the resnet18
    checkpoint download from "https://download.pytorch.org/models/resnet18-f37072fd.pth"
    (44.7M, ~172MB/s).

Cell 14 (verify feature shapes):

    # We can verify that the shapes of the features expected by the policy match the ones from the observations
    # produced by the environment
    print(policy.config.input_features)
    print(env.observation_space)

    Output:
    {'observation.images.top': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(14,))}
    Dict('top': Box(0, 255, (480, 640, 3), uint8))

Cell 15 (verify action shapes):

    # Similarly, we can check that the actions produced by the policy will match the actions expected by the
    # environment
    print(policy.config.output_features)
    print(env.action_space)

    Output:
    {'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(14,))}
    Box(-1.0, 1.0, (14,), float32)

Cell 16 (reset for rollout; fails):

    # Reset the policy and environments to prepare for rollout
    policy.reset()
    numpy_observation, info = env.reset(seed=42)

    Output (traceback, condensed): env.reset() raises FatalError: gladLoadGL error.
    The call descends through gymnasium's TimeLimit, OrderEnforcing, and
    PassiveEnvChecker wrappers into gym_aloha's AlohaEnv.reset, then dm_control's
    Physics.render, which fails while creating an OpenGL rendering context
    (MjrContext) for offscreen rendering.

Cell 17: empty.

Notebook metadata: kernelspec "lerobot" (python3), language_info Python 3.10.16, nbformat 4, nbformat_minor 2.
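The gladLoadGL FatalError in Cell 16 typically means no OpenGL context is available (a headless machine). A common workaround, not part of this commit, is to select a headless MuJoCo rendering backend via the MUJOCO_GL environment variable before importing anything MuJoCo-related; a hedged sketch:

    # Workaround sketch for the Cell 16 failure: pick a headless renderer
    # before gym_aloha / dm_control are imported. "egl" needs a GPU driver
    # with EGL support; "osmesa" is a CPU-only software fallback.
    import os
    os.environ.setdefault("MUJOCO_GL", "egl")  # or "osmesa"

    import gymnasium as gym
    import gym_aloha  # noqa: F401

    env = gym.make("gym_aloha/AlohaInsertion-v0", obs_type="pixels", max_episode_steps=300)
    obs, info = env.reset(seed=42)
    print(type(obs))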