diff --git a/output/rollout.mp4 b/output/rollout.mp4 new file mode 100644 index 00000000..d4b24771 Binary files /dev/null and b/output/rollout.mp4 differ diff --git a/tester.ipynb b/tester.ipynb new file mode 100644 index 00000000..c567051e --- /dev/null +++ b/tester.ipynb @@ -0,0 +1,674 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import gym_pusht # noqa: F401\n", + "import gymnasium as gym\n", + "import imageio\n", + "import numpy\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Select your device\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):\n", + "pretrained_policy_path = \"IliaLarchenko/dot_pusht_keypoints\"" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from lerobot.common.policies.dot.modeling_dot import DOTPolicy\n", + "policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "env = gym.make(\n", + " \"gym_pusht/PushT-v0\",\n", + " obs_type=\"environment_state_agent_pos\",\n", + " max_episode_steps=300,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'observation.state': PolicyFeature(type=, shape=(2,)), 'observation.environment_state': PolicyFeature(type=, shape=(16,))}\n", + "Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'environment_state': Box(0.0, 512.0, (16,), float64))\n" + ] + } + ], + "source": [ + "print(policy.config.input_features)\n", + "print(env.observation_space)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'action': PolicyFeature(type=, shape=(2,))}\n", + "Box(0.0, 512.0, (2,), float32)\n" + ] + } + ], + "source": [ + "print(policy.config.output_features)\n", + "print(env.action_space)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "policy.reset()\n", + "numpy_observation, info = env.reset(seed=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare to collect every rewards and all the frames of the episode,\n", + "# from initial state to final state.\n", + "rewards = []\n", + "frames = []\n", + "\n", + "# Render frame of the initial state\n", + "frames.append(env.render())" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step=0 reward=np.float64(0.0) terminated=False\n", + "step=1 reward=np.float64(0.0) terminated=False\n", + "step=2 reward=np.float64(0.0) terminated=False\n", + "step=3 reward=np.float64(0.0) terminated=False\n", + "step=4 reward=np.float64(0.0) terminated=False\n", + "step=5 reward=np.float64(0.0) terminated=False\n", + "step=6 reward=np.float64(0.0) terminated=False\n", + "step=7 reward=np.float64(0.0) terminated=False\n", + "step=8 reward=np.float64(0.0) terminated=False\n", + "step=9 reward=np.float64(0.0) terminated=False\n", + "step=10 reward=np.float64(0.0) terminated=False\n", + "step=11 reward=np.float64(0.0) terminated=False\n", + "step=12 reward=np.float64(0.0) terminated=False\n", + "step=13 reward=np.float64(0.0) terminated=False\n", + "step=14 reward=np.float64(0.0) terminated=False\n", + "step=15 reward=np.float64(0.0) terminated=False\n", + "step=16 reward=np.float64(0.0) terminated=False\n", + "step=17 reward=np.float64(0.0) terminated=False\n", + "step=18 reward=np.float64(0.0) terminated=False\n", + "step=19 reward=np.float64(0.0) terminated=False\n", + "step=20 reward=np.float64(0.0) terminated=False\n", + "step=21 reward=np.float64(0.0) terminated=False\n", + "step=22 reward=np.float64(0.0009941544780861455) terminated=False\n", + "step=23 reward=np.float64(0.033647507519038757) terminated=False\n", + "step=24 reward=np.float64(0.07026086006261555) terminated=False\n", + "step=25 reward=np.float64(0.10069667553409196) terminated=False\n", + "step=26 reward=np.float64(0.11389926069925992) terminated=False\n", + "step=27 reward=np.float64(0.12027077768723497) terminated=False\n", + "step=28 reward=np.float64(0.12486582623684722) terminated=False\n", + "step=29 reward=np.float64(0.12815916861048604) terminated=False\n", + "step=30 reward=np.float64(0.1303391815805222) terminated=False\n", + "step=31 reward=np.float64(0.1315231117258188) terminated=False\n", + "step=32 reward=np.float64(0.13221640549835664) terminated=False\n", + "step=33 reward=np.float64(0.13254763259209015) terminated=False\n", + "step=34 reward=np.float64(0.13263558368425837) terminated=False\n", + "step=35 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=36 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=37 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=38 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=39 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=40 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=41 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=42 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=43 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=44 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=45 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=46 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=47 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=48 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=49 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=50 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=51 reward=np.float64(0.13263932937572084) terminated=False\n", + "step=52 reward=np.float64(0.14872285364307145) terminated=False\n", + "step=53 reward=np.float64(0.19847005044261715) terminated=False\n", + "step=54 reward=np.float64(0.24338272205852812) terminated=False\n", + "step=55 reward=np.float64(0.2667243347061481) terminated=False\n", + "step=56 reward=np.float64(0.2691675276421592) terminated=False\n", + "step=57 reward=np.float64(0.3018254762158707) terminated=False\n", + "step=58 reward=np.float64(0.3613501686331564) terminated=False\n", + "step=59 reward=np.float64(0.4613512243665896) terminated=False\n", + "step=60 reward=np.float64(0.5617656688929643) terminated=False\n", + "step=61 reward=np.float64(0.5747609180351871) terminated=False\n", + "step=62 reward=np.float64(0.507048651118485) terminated=False\n", + "step=63 reward=np.float64(0.44332042287270484) terminated=False\n", + "step=64 reward=np.float64(0.3993804222553378) terminated=False\n", + "step=65 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=66 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=67 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=68 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=69 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=70 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=71 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=72 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=73 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=74 reward=np.float64(0.36941592278664487) terminated=False\n", + "step=75 reward=np.float64(0.4322328474940646) terminated=False\n", + "step=76 reward=np.float64(0.4818152566968738) terminated=False\n", + "step=77 reward=np.float64(0.5252535051167734) terminated=False\n", + "step=78 reward=np.float64(0.5586446249197407) terminated=False\n", + "step=79 reward=np.float64(0.5885022076599307) terminated=False\n", + "step=80 reward=np.float64(0.5977994643852952) terminated=False\n", + "step=81 reward=np.float64(0.597859201570885) terminated=False\n", + "step=82 reward=np.float64(0.597859201570885) terminated=False\n", + "step=83 reward=np.float64(0.597859201570885) terminated=False\n", + "step=84 reward=np.float64(0.597859201570885) terminated=False\n", + "step=85 reward=np.float64(0.597859201570885) terminated=False\n", + "step=86 reward=np.float64(0.597859201570885) terminated=False\n", + "step=87 reward=np.float64(0.597859201570885) terminated=False\n", + "step=88 reward=np.float64(0.597859201570885) terminated=False\n", + "step=89 reward=np.float64(0.6876341127908622) terminated=False\n", + "step=90 reward=np.float64(0.8166289152424572) terminated=False\n", + "step=91 reward=np.float64(0.9421614978354362) terminated=False\n", + "step=92 reward=np.float64(0.9441608976568224) terminated=False\n", + "step=93 reward=np.float64(0.9104163604296966) terminated=False\n", + "step=94 reward=np.float64(0.909182661371769) terminated=False\n", + "step=95 reward=np.float64(0.909182661371769) terminated=False\n", + "step=96 reward=np.float64(0.909182661371769) terminated=False\n", + "step=97 reward=np.float64(0.909182661371769) terminated=False\n", + "step=98 reward=np.float64(0.909182661371769) terminated=False\n", + "step=99 reward=np.float64(0.909182661371769) terminated=False\n", + "step=100 reward=np.float64(0.909182661371769) terminated=False\n", + "step=101 reward=np.float64(0.909182661371769) terminated=False\n", + "step=102 reward=np.float64(0.909182661371769) terminated=False\n", + "step=103 reward=np.float64(0.9340357871805705) terminated=False\n", + "step=104 reward=np.float64(0.8851102121651142) terminated=False\n", + "step=105 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=106 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=107 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=108 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=109 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=110 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=111 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=112 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=113 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=114 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=115 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=116 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=117 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=118 reward=np.float64(0.8809768749693764) terminated=False\n", + "step=119 reward=np.float64(0.9518089158714292) terminated=False\n", + "step=120 reward=np.float64(0.9405458729516311) terminated=False\n", + "step=121 reward=np.float64(0.935511214687435) terminated=False\n", + "step=122 reward=np.float64(0.935511214687435) terminated=False\n", + "step=123 reward=np.float64(0.935511214687435) terminated=False\n", + "step=124 reward=np.float64(0.935511214687435) terminated=False\n", + "step=125 reward=np.float64(0.935511214687435) terminated=False\n", + "step=126 reward=np.float64(0.935511214687435) terminated=False\n", + "step=127 reward=np.float64(0.935511214687435) terminated=False\n", + "step=128 reward=np.float64(0.935511214687435) terminated=False\n", + "step=129 reward=np.float64(0.935511214687435) terminated=False\n", + "step=130 reward=np.float64(0.935511214687435) terminated=False\n", + "step=131 reward=np.float64(0.935511214687435) terminated=False\n", + "step=132 reward=np.float64(0.935511214687435) terminated=False\n", + "step=133 reward=np.float64(0.935511214687435) terminated=False\n", + "step=134 reward=np.float64(0.935511214687435) terminated=False\n", + "step=135 reward=np.float64(0.9534990217822209) terminated=False\n", + "step=136 reward=np.float64(0.9596585109597399) terminated=False\n", + "step=137 reward=np.float64(0.882875733420291) terminated=False\n", + "step=138 reward=np.float64(0.8277880190838034) terminated=False\n", + "step=139 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=140 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=141 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=142 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=143 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=144 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=145 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=146 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=147 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=148 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=149 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=150 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=151 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=152 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=153 reward=np.float64(0.8211529871155266) terminated=False\n", + "step=154 reward=np.float64(0.856408395120982) terminated=False\n", + "step=155 reward=np.float64(0.9304040416833055) terminated=False\n", + "step=156 reward=np.float64(0.9770812279113622) terminated=False\n", + "step=157 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=158 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=159 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=160 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=161 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=162 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=163 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=164 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=165 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=166 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=167 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=168 reward=np.float64(0.9944597232685968) terminated=False\n", + "step=169 reward=np.float64(0.9820541217675358) terminated=False\n", + "step=170 reward=np.float64(0.8765528119646949) terminated=False\n", + "step=171 reward=np.float64(0.8231919366320396) terminated=False\n", + "step=172 reward=np.float64(0.7926155231821123) terminated=False\n", + "step=173 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=174 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=175 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=176 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=177 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=178 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=179 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=180 reward=np.float64(0.7902960563492054) terminated=False\n", + "step=181 reward=np.float64(0.8158199658870418) terminated=False\n", + "step=182 reward=np.float64(0.8191627090126786) terminated=False\n", + "step=183 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=184 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=185 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=186 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=187 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=188 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=189 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=190 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=191 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=192 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=193 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=194 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=195 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=196 reward=np.float64(0.8172224839948001) terminated=False\n", + "step=197 reward=np.float64(0.878735078350138) terminated=False\n", + "step=198 reward=np.float64(0.8564816396314117) terminated=False\n", + "step=199 reward=np.float64(0.7970005244772627) terminated=False\n", + "step=200 reward=np.float64(0.7688729860960439) terminated=False\n", + "step=201 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=202 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=203 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=204 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=205 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=206 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=207 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=208 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=209 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=210 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=211 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=212 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=213 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=214 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=215 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=216 reward=np.float64(0.7671148163486476) terminated=False\n", + "step=217 reward=np.float64(0.8045352949993082) terminated=False\n", + "step=218 reward=np.float64(0.8328184612705187) terminated=False\n", + "step=219 reward=np.float64(0.8558996801195216) terminated=False\n", + "step=220 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=221 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=222 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=223 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=224 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=225 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=226 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=227 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=228 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=229 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=230 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=231 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=232 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=233 reward=np.float64(0.8576887923798919) terminated=False\n", + "step=234 reward=np.float64(0.935454295639811) terminated=False\n", + "step=235 reward=np.float64(0.9874094853870982) terminated=False\n", + "step=236 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=237 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=238 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=239 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=240 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=241 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=242 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=243 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=244 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=245 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=246 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=247 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=248 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=249 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=250 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=251 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=252 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=253 reward=np.float64(0.9719847917595543) terminated=False\n", + "step=254 reward=np.float64(0.9727790631955697) terminated=False\n", + "step=255 reward=np.float64(0.946125141646605) terminated=False\n", + "step=256 reward=np.float64(0.9368755165399575) terminated=False\n", + "step=257 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=258 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=259 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=260 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=261 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=262 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=263 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=264 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=265 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=266 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=267 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=268 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=269 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=270 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=271 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=272 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=273 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=274 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=275 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=276 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=277 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=278 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=279 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=280 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=281 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=282 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=283 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=284 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=285 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=286 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=287 reward=np.float64(0.9360986686274937) terminated=False\n", + "step=288 reward=np.float64(0.9316550466986755) terminated=False\n", + "step=289 reward=np.float64(0.9218676877631473) terminated=False\n", + "step=290 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=291 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=292 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=293 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=294 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=295 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=296 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=297 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=298 reward=np.float64(0.9213441513220694) terminated=False\n", + "step=299 reward=np.float64(0.9213441513220694) terminated=False\n" + ] + } + ], + "source": [ + "step = 0\n", + "done = False\n", + "\n", + "while not done:\n", + " # Prepare observation for the policy\n", + " state = torch.from_numpy(numpy_observation[\"agent_pos\"]) # Agent position\n", + " env_state = torch.from_numpy(numpy_observation[\"environment_state\"]) # Environment state\n", + "\n", + " # Convert to float32\n", + " state = state.to(torch.float32)\n", + " env_state = env_state.to(torch.float32)\n", + "\n", + " # Send data tensors from CPU to GPU\n", + " state = state.to(device, non_blocking=True)\n", + " env_state = env_state.to(device, non_blocking=True)\n", + "\n", + " # Add extra (empty) batch dimension, required to forward the policy\n", + " state = state.unsqueeze(0)\n", + " env_state = env_state.unsqueeze(0)\n", + "\n", + " # Create the policy input dictionary\n", + " observation = {\n", + " \"observation.state\": state,\n", + " \"observation.environment_state\": env_state, # Add environment_state here\n", + " }\n", + "\n", + " # Predict the next action with respect to the current observation\n", + " with torch.inference_mode():\n", + " action = policy.select_action(observation)\n", + "\n", + " # Prepare the action for the environment\n", + " numpy_action = action.squeeze(0).to(\"cpu\").numpy()\n", + "\n", + " # Step through the environment and receive a new observation\n", + " numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)\n", + " print(f\"{step=} {reward=} {terminated=}\")\n", + "\n", + " # Keep track of all the rewards and frames\n", + " rewards.append(reward)\n", + " frames.append(env.render())\n", + "\n", + " # The rollout is considered done when the success state is reached (i.e. terminated is True),\n", + " # or the maximum number of iterations is reached (i.e. truncated is True)\n", + " done = terminated or truncated or done\n", + " step += 1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Failure!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (680, 680) to (688, 688) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Video of the evaluation is available in '/home/lerobot/output/rollout.mp4'.\n" + ] + } + ], + "source": [ + "if terminated:\n", + " print(\"Success!\")\n", + "else:\n", + " print(\"Failure!\")\n", + "\n", + "# Get the speed of environment (i.e. its number of frames per second).\n", + "fps = env.metadata[\"render_fps\"]\n", + "\n", + "# Encode all frames into a mp4 video.\n", + "video_path = \"/home/lerobot/output/rollout.mp4\"\n", + "imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)\n", + "\n", + "print(f\"Video of the evaluation is available in '{video_path}'.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "#now on aloha\n", + "import imageio\n", + "import gymnasium as gym\n", + "import numpy as np\n", + "import gym_aloha\n", + "env = gym.make(\n", + " \"gym_aloha/AlohaInsertion-v0\",\n", + " obs_type=\"pixels\",\n", + " max_episode_steps=300,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/envs/lerobot/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n", + "100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]\n" + ] + } + ], + "source": [ + "from lerobot.common.policies.dot.modeling_dot import DOTPolicy\n", + "pretrained_policy_path = \"IliaLarchenko/dot_bimanual_insert\"\n", + "policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'observation.images.top': PolicyFeature(type=, shape=(3, 480, 640)), 'observation.state': PolicyFeature(type=, shape=(14,))}\n", + "Dict('top': Box(0, 255, (480, 640, 3), uint8))\n" + ] + } + ], + "source": [ + "# We can verify that the shapes of the features expected by the policy match the ones from the observations\n", + "# produced by the environment\n", + "print(policy.config.input_features)\n", + "print(env.observation_space)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'action': PolicyFeature(type=, shape=(14,))}\n", + "Box(-1.0, 1.0, (14,), float32)\n" + ] + } + ], + "source": [ + "# Similarly, we can check that the actions produced by the policy will match the actions expected by the\n", + "# environment\n", + "print(policy.config.output_features)\n", + "print(env.action_space)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "ename": "FatalError", + "evalue": "gladLoadGL error", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFatalError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Reset the policy and environments to prepare for rollout\u001b[39;00m\n\u001b[1;32m 2\u001b[0m policy\u001b[38;5;241m.\u001b[39mreset()\n\u001b[0;32m----> 3\u001b[0m numpy_observation, info \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/time_limit.py:75\u001b[0m, in \u001b[0;36mTimeLimit.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.\u001b[39;00m\n\u001b[1;32m 67\u001b[0m \n\u001b[1;32m 68\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m The reset environment\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_elapsed_steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/order_enforcing.py:61\u001b[0m, in \u001b[0;36mOrderEnforcing.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with `kwargs`.\"\"\"\u001b[39;00m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_has_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/env_checker.py:57\u001b[0m, in \u001b[0;36mPassiveEnvChecker.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43menv_reset_passive_checker\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39mreset(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:186\u001b[0m, in \u001b[0;36menv_reset_passive_checker\u001b[0;34m(env, **kwargs)\u001b[0m\n\u001b[1;32m 181\u001b[0m logger\u001b[38;5;241m.\u001b[39mdeprecation(\n\u001b[1;32m 182\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCurrent gymnasium version requires that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 183\u001b[0m )\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# Checks the result of env.reset with kwargs\u001b[39;00m\n\u001b[0;32m--> 186\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(result, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 189\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 190\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(result)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m )\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/env.py:166\u001b[0m, in \u001b[0;36mAlohaEnv.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtask)\n\u001b[0;32m--> 166\u001b[0m raw_obs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_raw_obs(raw_obs\u001b[38;5;241m.\u001b[39mobservation)\n\u001b[1;32m 170\u001b[0m info \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis_success\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mFalse\u001b[39;00m}\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/rl/control.py:89\u001b[0m, in \u001b[0;36mEnvironment.reset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mreset_context():\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_task\u001b[38;5;241m.\u001b[39minitialize_episode(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics)\n\u001b[0;32m---> 89\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_task\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flat_observation:\n\u001b[1;32m 91\u001b[0m observation \u001b[38;5;241m=\u001b[39m flatten_observation(observation)\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/tasks/sim.py:92\u001b[0m, in \u001b[0;36mBimanualViperXTask.get_observation\u001b[0;34m(self, physics)\u001b[0m\n\u001b[1;32m 90\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menv_state\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_env_state(physics)\n\u001b[1;32m 91\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m---> 92\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtop\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mphysics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrender\u001b[49m\u001b[43m(\u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m480\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m640\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtop\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 94\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvis\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfront_close\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:223\u001b[0m, in \u001b[0;36mPhysics.render\u001b[0;34m(self, height, width, camera_id, overlays, depth, segmentation, scene_option, render_flag_overrides, scene_callback)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrender\u001b[39m(\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 180\u001b[0m height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m240\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01mNone\u001b[39;00m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 190\u001b[0m ):\n\u001b[1;32m 191\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns a camera view as a NumPy array of pixel values.\u001b[39;00m\n\u001b[1;32m 192\u001b[0m \n\u001b[1;32m 193\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;124;03m The rendered RGB, depth or segmentation image.\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 223\u001b[0m camera \u001b[38;5;241m=\u001b[39m \u001b[43mCamera\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mphysics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwidth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcamera_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mscene_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscene_callback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m image \u001b[38;5;241m=\u001b[39m camera\u001b[38;5;241m.\u001b[39mrender(\n\u001b[1;32m 230\u001b[0m overlays\u001b[38;5;241m=\u001b[39moverlays, depth\u001b[38;5;241m=\u001b[39mdepth, segmentation\u001b[38;5;241m=\u001b[39msegmentation,\n\u001b[1;32m 231\u001b[0m scene_option\u001b[38;5;241m=\u001b[39mscene_option, render_flag_overrides\u001b[38;5;241m=\u001b[39mrender_flag_overrides)\n\u001b[1;32m 232\u001b[0m camera\u001b[38;5;241m.\u001b[39m_scene\u001b[38;5;241m.\u001b[39mfree() \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:711\u001b[0m, in \u001b[0;36mCamera.__init__\u001b[0;34m(self, physics, height, width, camera_id, max_geom, scene_callback)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rgb_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width, \u001b[38;5;241m3\u001b[39m), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39muint8)\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_depth_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m--> 711\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontexts\u001b[49m\u001b[38;5;241m.\u001b[39mmujoco \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 712\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mgl\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[1;32m 713\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN,\n\u001b[1;32m 714\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mmujoco\u001b[38;5;241m.\u001b[39mptr)\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:533\u001b[0m, in \u001b[0;36mPhysics.contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts_lock:\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts:\n\u001b[0;32m--> 533\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_rendering_contexts\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 534\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:519\u001b[0m, in \u001b[0;36mPhysics._make_rendering_contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 516\u001b[0m render_context \u001b[38;5;241m=\u001b[39m _render\u001b[38;5;241m.\u001b[39mRenderer(\n\u001b[1;32m 517\u001b[0m max_width\u001b[38;5;241m=\u001b[39mmax_width, max_height\u001b[38;5;241m=\u001b[39mmax_height)\n\u001b[1;32m 518\u001b[0m \u001b[38;5;66;03m# Create the MuJoCo context.\u001b[39;00m\n\u001b[0;32m--> 519\u001b[0m mujoco_context \u001b[38;5;241m=\u001b[39m \u001b[43mwrapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrender_context\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 520\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts \u001b[38;5;241m=\u001b[39m Contexts(gl\u001b[38;5;241m=\u001b[39mrender_context, mujoco\u001b[38;5;241m=\u001b[39mmujoco_context)\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/wrapper/core.py:603\u001b[0m, in \u001b[0;36mMjrContext.__init__\u001b[0;34m(self, model, gl_context, font_scale)\u001b[0m\n\u001b[1;32m 601\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gl_context \u001b[38;5;241m=\u001b[39m gl_context\n\u001b[1;32m 602\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m gl_context\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[0;32m--> 603\u001b[0m ptr \u001b[38;5;241m=\u001b[39m \u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmujoco\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mptr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfont_scale\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 604\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN, ptr)\n\u001b[1;32m 605\u001b[0m gl_context\u001b[38;5;241m.\u001b[39mkeep_alive(ptr)\n", + "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/_render/executor/render_executor.py:138\u001b[0m, in \u001b[0;36mPassthroughRenderExecutor.call\u001b[0;34m(self, func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcall\u001b[39m(\u001b[38;5;28mself\u001b[39m, func, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_locked()\n\u001b[0;32m--> 138\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mFatalError\u001b[0m: gladLoadGL error" + ] + } + ], + "source": [ + "# Reset the policy and environments to prepare for rollout\n", + "policy.reset()\n", + "numpy_observation, info = env.reset(seed=42)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lerobot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}