{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import gym_pusht  # noqa: F401\n",
    "import gymnasium as gym\n",
    "import imageio\n",
    "import numpy\n",
    "import torch\n",
    "\n",
    "from lerobot.common.policies.dot.modeling_dot import DOTPolicy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select your device\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Provide the [Hugging Face repo id](https://huggingface.co/IliaLarchenko/dot_pusht_keypoints)\n",
    "# of the pretrained DOT policy:\n",
    "pretrained_policy_path = \"IliaLarchenko/dot_pusht_keypoints\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the observation type that provides both `agent_pos` and `environment_state`,\n",
    "# which are fed to the policy as `observation.state` and `observation.environment_state`.\n",
    "env = gym.make(\n",
    "    \"gym_pusht/PushT-v0\",\n",
    "    obs_type=\"environment_state_agent_pos\",\n",
    "    max_episode_steps=300,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'observation.state': PolicyFeature(type=, shape=(2,)), 'observation.environment_state': PolicyFeature(type=, shape=(16,))}\n",
      "Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'environment_state': Box(0.0, 512.0, (16,), float64))\n"
     ]
    }
   ],
   "source": [
    "# Verify that the features expected by the policy match the observations produced by the environment\n",
    "print(policy.config.input_features)\n",
    "print(env.observation_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'action': PolicyFeature(type=, shape=(2,))}\n",
      "Box(0.0, 512.0, (2,), float32)\n"
     ]
    }
   ],
   "source": [
    "# Verify that the actions produced by the policy match the environment's action space\n",
    "print(policy.config.output_features)\n",
    "print(env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the policy and the environment (fixed seed) before every rollout.\n",
    "policy.reset()\n",
    "numpy_observation, info = env.reset(seed=42)\n",
    "\n",
    "# Prepare to collect every reward and all the frames of the episode,\n",
    "# from initial state to final state. Doing this in the same cell as the reset\n",
    "# keeps the buffers in sync with the episode whenever the cell is re-run.\n",
    "rewards = []\n",
    "frames = []\n",
    "\n",
    "# Render frame of the initial state\n",
    "frames.append(env.render())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step=0 reward=np.float64(0.0) terminated=False\n",
      "step=1 reward=np.float64(0.0) terminated=False\n",
      "step=2 reward=np.float64(0.0) terminated=False\n",
      "step=3 reward=np.float64(0.0) terminated=False\n",
      "step=4 reward=np.float64(0.0) terminated=False\n",
      "step=5 reward=np.float64(0.0) terminated=False\n",
      "step=6 reward=np.float64(0.0) terminated=False\n",
      "step=7 reward=np.float64(0.0) terminated=False\n",
      "step=8 reward=np.float64(0.0) terminated=False\n",
      "step=9 reward=np.float64(0.0) terminated=False\n",
      "step=10 reward=np.float64(0.0) terminated=False\n",
      "step=11 reward=np.float64(0.0) terminated=False\n",
      "step=12 reward=np.float64(0.0) terminated=False\n",
      "step=13 reward=np.float64(0.0) terminated=False\n",
      "step=14 reward=np.float64(0.0) terminated=False\n",
      "step=15 reward=np.float64(0.0) terminated=False\n",
      "step=16 reward=np.float64(0.0) terminated=False\n",
      "step=17 reward=np.float64(0.0) terminated=False\n",
      "step=18 reward=np.float64(0.0) terminated=False\n",
      "step=19 reward=np.float64(0.0) terminated=False\n",
      "step=20 reward=np.float64(0.0) terminated=False\n",
      "step=21 reward=np.float64(0.0) terminated=False\n",
      "step=22 reward=np.float64(0.0009941544780861455) terminated=False\n",
      "step=23 reward=np.float64(0.033647507519038757) terminated=False\n",
      "step=24 reward=np.float64(0.07026086006261555) terminated=False\n",
      "step=25 reward=np.float64(0.10069667553409196) terminated=False\n",
      "step=26 
reward=np.float64(0.11389926069925992) terminated=False\n", "step=27 reward=np.float64(0.12027077768723497) terminated=False\n", "step=28 reward=np.float64(0.12486582623684722) terminated=False\n", "step=29 reward=np.float64(0.12815916861048604) terminated=False\n", "step=30 reward=np.float64(0.1303391815805222) terminated=False\n", "step=31 reward=np.float64(0.1315231117258188) terminated=False\n", "step=32 reward=np.float64(0.13221640549835664) terminated=False\n", "step=33 reward=np.float64(0.13254763259209015) terminated=False\n", "step=34 reward=np.float64(0.13263558368425837) terminated=False\n", "step=35 reward=np.float64(0.13263932937572084) terminated=False\n", "step=36 reward=np.float64(0.13263932937572084) terminated=False\n", "step=37 reward=np.float64(0.13263932937572084) terminated=False\n", "step=38 reward=np.float64(0.13263932937572084) terminated=False\n", "step=39 reward=np.float64(0.13263932937572084) terminated=False\n", "step=40 reward=np.float64(0.13263932937572084) terminated=False\n", "step=41 reward=np.float64(0.13263932937572084) terminated=False\n", "step=42 reward=np.float64(0.13263932937572084) terminated=False\n", "step=43 reward=np.float64(0.13263932937572084) terminated=False\n", "step=44 reward=np.float64(0.13263932937572084) terminated=False\n", "step=45 reward=np.float64(0.13263932937572084) terminated=False\n", "step=46 reward=np.float64(0.13263932937572084) terminated=False\n", "step=47 reward=np.float64(0.13263932937572084) terminated=False\n", "step=48 reward=np.float64(0.13263932937572084) terminated=False\n", "step=49 reward=np.float64(0.13263932937572084) terminated=False\n", "step=50 reward=np.float64(0.13263932937572084) terminated=False\n", "step=51 reward=np.float64(0.13263932937572084) terminated=False\n", "step=52 reward=np.float64(0.14872285364307145) terminated=False\n", "step=53 reward=np.float64(0.19847005044261715) terminated=False\n", "step=54 reward=np.float64(0.24338272205852812) terminated=False\n", "step=55 
reward=np.float64(0.2667243347061481) terminated=False\n", "step=56 reward=np.float64(0.2691675276421592) terminated=False\n", "step=57 reward=np.float64(0.3018254762158707) terminated=False\n", "step=58 reward=np.float64(0.3613501686331564) terminated=False\n", "step=59 reward=np.float64(0.4613512243665896) terminated=False\n", "step=60 reward=np.float64(0.5617656688929643) terminated=False\n", "step=61 reward=np.float64(0.5747609180351871) terminated=False\n", "step=62 reward=np.float64(0.507048651118485) terminated=False\n", "step=63 reward=np.float64(0.44332042287270484) terminated=False\n", "step=64 reward=np.float64(0.3993804222553378) terminated=False\n", "step=65 reward=np.float64(0.36941592278664487) terminated=False\n", "step=66 reward=np.float64(0.36941592278664487) terminated=False\n", "step=67 reward=np.float64(0.36941592278664487) terminated=False\n", "step=68 reward=np.float64(0.36941592278664487) terminated=False\n", "step=69 reward=np.float64(0.36941592278664487) terminated=False\n", "step=70 reward=np.float64(0.36941592278664487) terminated=False\n", "step=71 reward=np.float64(0.36941592278664487) terminated=False\n", "step=72 reward=np.float64(0.36941592278664487) terminated=False\n", "step=73 reward=np.float64(0.36941592278664487) terminated=False\n", "step=74 reward=np.float64(0.36941592278664487) terminated=False\n", "step=75 reward=np.float64(0.4322328474940646) terminated=False\n", "step=76 reward=np.float64(0.4818152566968738) terminated=False\n", "step=77 reward=np.float64(0.5252535051167734) terminated=False\n", "step=78 reward=np.float64(0.5586446249197407) terminated=False\n", "step=79 reward=np.float64(0.5885022076599307) terminated=False\n", "step=80 reward=np.float64(0.5977994643852952) terminated=False\n", "step=81 reward=np.float64(0.597859201570885) terminated=False\n", "step=82 reward=np.float64(0.597859201570885) terminated=False\n", "step=83 reward=np.float64(0.597859201570885) terminated=False\n", "step=84 
reward=np.float64(0.597859201570885) terminated=False\n", "step=85 reward=np.float64(0.597859201570885) terminated=False\n", "step=86 reward=np.float64(0.597859201570885) terminated=False\n", "step=87 reward=np.float64(0.597859201570885) terminated=False\n", "step=88 reward=np.float64(0.597859201570885) terminated=False\n", "step=89 reward=np.float64(0.6876341127908622) terminated=False\n", "step=90 reward=np.float64(0.8166289152424572) terminated=False\n", "step=91 reward=np.float64(0.9421614978354362) terminated=False\n", "step=92 reward=np.float64(0.9441608976568224) terminated=False\n", "step=93 reward=np.float64(0.9104163604296966) terminated=False\n", "step=94 reward=np.float64(0.909182661371769) terminated=False\n", "step=95 reward=np.float64(0.909182661371769) terminated=False\n", "step=96 reward=np.float64(0.909182661371769) terminated=False\n", "step=97 reward=np.float64(0.909182661371769) terminated=False\n", "step=98 reward=np.float64(0.909182661371769) terminated=False\n", "step=99 reward=np.float64(0.909182661371769) terminated=False\n", "step=100 reward=np.float64(0.909182661371769) terminated=False\n", "step=101 reward=np.float64(0.909182661371769) terminated=False\n", "step=102 reward=np.float64(0.909182661371769) terminated=False\n", "step=103 reward=np.float64(0.9340357871805705) terminated=False\n", "step=104 reward=np.float64(0.8851102121651142) terminated=False\n", "step=105 reward=np.float64(0.8809768749693764) terminated=False\n", "step=106 reward=np.float64(0.8809768749693764) terminated=False\n", "step=107 reward=np.float64(0.8809768749693764) terminated=False\n", "step=108 reward=np.float64(0.8809768749693764) terminated=False\n", "step=109 reward=np.float64(0.8809768749693764) terminated=False\n", "step=110 reward=np.float64(0.8809768749693764) terminated=False\n", "step=111 reward=np.float64(0.8809768749693764) terminated=False\n", "step=112 reward=np.float64(0.8809768749693764) terminated=False\n", "step=113 
reward=np.float64(0.8809768749693764) terminated=False\n", "step=114 reward=np.float64(0.8809768749693764) terminated=False\n", "step=115 reward=np.float64(0.8809768749693764) terminated=False\n", "step=116 reward=np.float64(0.8809768749693764) terminated=False\n", "step=117 reward=np.float64(0.8809768749693764) terminated=False\n", "step=118 reward=np.float64(0.8809768749693764) terminated=False\n", "step=119 reward=np.float64(0.9518089158714292) terminated=False\n", "step=120 reward=np.float64(0.9405458729516311) terminated=False\n", "step=121 reward=np.float64(0.935511214687435) terminated=False\n", "step=122 reward=np.float64(0.935511214687435) terminated=False\n", "step=123 reward=np.float64(0.935511214687435) terminated=False\n", "step=124 reward=np.float64(0.935511214687435) terminated=False\n", "step=125 reward=np.float64(0.935511214687435) terminated=False\n", "step=126 reward=np.float64(0.935511214687435) terminated=False\n", "step=127 reward=np.float64(0.935511214687435) terminated=False\n", "step=128 reward=np.float64(0.935511214687435) terminated=False\n", "step=129 reward=np.float64(0.935511214687435) terminated=False\n", "step=130 reward=np.float64(0.935511214687435) terminated=False\n", "step=131 reward=np.float64(0.935511214687435) terminated=False\n", "step=132 reward=np.float64(0.935511214687435) terminated=False\n", "step=133 reward=np.float64(0.935511214687435) terminated=False\n", "step=134 reward=np.float64(0.935511214687435) terminated=False\n", "step=135 reward=np.float64(0.9534990217822209) terminated=False\n", "step=136 reward=np.float64(0.9596585109597399) terminated=False\n", "step=137 reward=np.float64(0.882875733420291) terminated=False\n", "step=138 reward=np.float64(0.8277880190838034) terminated=False\n", "step=139 reward=np.float64(0.8211529871155266) terminated=False\n", "step=140 reward=np.float64(0.8211529871155266) terminated=False\n", "step=141 reward=np.float64(0.8211529871155266) terminated=False\n", "step=142 
reward=np.float64(0.8211529871155266) terminated=False\n", "step=143 reward=np.float64(0.8211529871155266) terminated=False\n", "step=144 reward=np.float64(0.8211529871155266) terminated=False\n", "step=145 reward=np.float64(0.8211529871155266) terminated=False\n", "step=146 reward=np.float64(0.8211529871155266) terminated=False\n", "step=147 reward=np.float64(0.8211529871155266) terminated=False\n", "step=148 reward=np.float64(0.8211529871155266) terminated=False\n", "step=149 reward=np.float64(0.8211529871155266) terminated=False\n", "step=150 reward=np.float64(0.8211529871155266) terminated=False\n", "step=151 reward=np.float64(0.8211529871155266) terminated=False\n", "step=152 reward=np.float64(0.8211529871155266) terminated=False\n", "step=153 reward=np.float64(0.8211529871155266) terminated=False\n", "step=154 reward=np.float64(0.856408395120982) terminated=False\n", "step=155 reward=np.float64(0.9304040416833055) terminated=False\n", "step=156 reward=np.float64(0.9770812279113622) terminated=False\n", "step=157 reward=np.float64(0.9944597232685968) terminated=False\n", "step=158 reward=np.float64(0.9944597232685968) terminated=False\n", "step=159 reward=np.float64(0.9944597232685968) terminated=False\n", "step=160 reward=np.float64(0.9944597232685968) terminated=False\n", "step=161 reward=np.float64(0.9944597232685968) terminated=False\n", "step=162 reward=np.float64(0.9944597232685968) terminated=False\n", "step=163 reward=np.float64(0.9944597232685968) terminated=False\n", "step=164 reward=np.float64(0.9944597232685968) terminated=False\n", "step=165 reward=np.float64(0.9944597232685968) terminated=False\n", "step=166 reward=np.float64(0.9944597232685968) terminated=False\n", "step=167 reward=np.float64(0.9944597232685968) terminated=False\n", "step=168 reward=np.float64(0.9944597232685968) terminated=False\n", "step=169 reward=np.float64(0.9820541217675358) terminated=False\n", "step=170 reward=np.float64(0.8765528119646949) terminated=False\n", "step=171 
reward=np.float64(0.8231919366320396) terminated=False\n", "step=172 reward=np.float64(0.7926155231821123) terminated=False\n", "step=173 reward=np.float64(0.7902960563492054) terminated=False\n", "step=174 reward=np.float64(0.7902960563492054) terminated=False\n", "step=175 reward=np.float64(0.7902960563492054) terminated=False\n", "step=176 reward=np.float64(0.7902960563492054) terminated=False\n", "step=177 reward=np.float64(0.7902960563492054) terminated=False\n", "step=178 reward=np.float64(0.7902960563492054) terminated=False\n", "step=179 reward=np.float64(0.7902960563492054) terminated=False\n", "step=180 reward=np.float64(0.7902960563492054) terminated=False\n", "step=181 reward=np.float64(0.8158199658870418) terminated=False\n", "step=182 reward=np.float64(0.8191627090126786) terminated=False\n", "step=183 reward=np.float64(0.8172224839948001) terminated=False\n", "step=184 reward=np.float64(0.8172224839948001) terminated=False\n", "step=185 reward=np.float64(0.8172224839948001) terminated=False\n", "step=186 reward=np.float64(0.8172224839948001) terminated=False\n", "step=187 reward=np.float64(0.8172224839948001) terminated=False\n", "step=188 reward=np.float64(0.8172224839948001) terminated=False\n", "step=189 reward=np.float64(0.8172224839948001) terminated=False\n", "step=190 reward=np.float64(0.8172224839948001) terminated=False\n", "step=191 reward=np.float64(0.8172224839948001) terminated=False\n", "step=192 reward=np.float64(0.8172224839948001) terminated=False\n", "step=193 reward=np.float64(0.8172224839948001) terminated=False\n", "step=194 reward=np.float64(0.8172224839948001) terminated=False\n", "step=195 reward=np.float64(0.8172224839948001) terminated=False\n", "step=196 reward=np.float64(0.8172224839948001) terminated=False\n", "step=197 reward=np.float64(0.878735078350138) terminated=False\n", "step=198 reward=np.float64(0.8564816396314117) terminated=False\n", "step=199 reward=np.float64(0.7970005244772627) terminated=False\n", "step=200 
reward=np.float64(0.7688729860960439) terminated=False\n", "step=201 reward=np.float64(0.7671148163486476) terminated=False\n", "step=202 reward=np.float64(0.7671148163486476) terminated=False\n", "step=203 reward=np.float64(0.7671148163486476) terminated=False\n", "step=204 reward=np.float64(0.7671148163486476) terminated=False\n", "step=205 reward=np.float64(0.7671148163486476) terminated=False\n", "step=206 reward=np.float64(0.7671148163486476) terminated=False\n", "step=207 reward=np.float64(0.7671148163486476) terminated=False\n", "step=208 reward=np.float64(0.7671148163486476) terminated=False\n", "step=209 reward=np.float64(0.7671148163486476) terminated=False\n", "step=210 reward=np.float64(0.7671148163486476) terminated=False\n", "step=211 reward=np.float64(0.7671148163486476) terminated=False\n", "step=212 reward=np.float64(0.7671148163486476) terminated=False\n", "step=213 reward=np.float64(0.7671148163486476) terminated=False\n", "step=214 reward=np.float64(0.7671148163486476) terminated=False\n", "step=215 reward=np.float64(0.7671148163486476) terminated=False\n", "step=216 reward=np.float64(0.7671148163486476) terminated=False\n", "step=217 reward=np.float64(0.8045352949993082) terminated=False\n", "step=218 reward=np.float64(0.8328184612705187) terminated=False\n", "step=219 reward=np.float64(0.8558996801195216) terminated=False\n", "step=220 reward=np.float64(0.8576887923798919) terminated=False\n", "step=221 reward=np.float64(0.8576887923798919) terminated=False\n", "step=222 reward=np.float64(0.8576887923798919) terminated=False\n", "step=223 reward=np.float64(0.8576887923798919) terminated=False\n", "step=224 reward=np.float64(0.8576887923798919) terminated=False\n", "step=225 reward=np.float64(0.8576887923798919) terminated=False\n", "step=226 reward=np.float64(0.8576887923798919) terminated=False\n", "step=227 reward=np.float64(0.8576887923798919) terminated=False\n", "step=228 reward=np.float64(0.8576887923798919) terminated=False\n", 
"step=229 reward=np.float64(0.8576887923798919) terminated=False\n", "step=230 reward=np.float64(0.8576887923798919) terminated=False\n", "step=231 reward=np.float64(0.8576887923798919) terminated=False\n", "step=232 reward=np.float64(0.8576887923798919) terminated=False\n", "step=233 reward=np.float64(0.8576887923798919) terminated=False\n", "step=234 reward=np.float64(0.935454295639811) terminated=False\n", "step=235 reward=np.float64(0.9874094853870982) terminated=False\n", "step=236 reward=np.float64(0.9719847917595543) terminated=False\n", "step=237 reward=np.float64(0.9719847917595543) terminated=False\n", "step=238 reward=np.float64(0.9719847917595543) terminated=False\n", "step=239 reward=np.float64(0.9719847917595543) terminated=False\n", "step=240 reward=np.float64(0.9719847917595543) terminated=False\n", "step=241 reward=np.float64(0.9719847917595543) terminated=False\n", "step=242 reward=np.float64(0.9719847917595543) terminated=False\n", "step=243 reward=np.float64(0.9719847917595543) terminated=False\n", "step=244 reward=np.float64(0.9719847917595543) terminated=False\n", "step=245 reward=np.float64(0.9719847917595543) terminated=False\n", "step=246 reward=np.float64(0.9719847917595543) terminated=False\n", "step=247 reward=np.float64(0.9719847917595543) terminated=False\n", "step=248 reward=np.float64(0.9719847917595543) terminated=False\n", "step=249 reward=np.float64(0.9719847917595543) terminated=False\n", "step=250 reward=np.float64(0.9719847917595543) terminated=False\n", "step=251 reward=np.float64(0.9719847917595543) terminated=False\n", "step=252 reward=np.float64(0.9719847917595543) terminated=False\n", "step=253 reward=np.float64(0.9719847917595543) terminated=False\n", "step=254 reward=np.float64(0.9727790631955697) terminated=False\n", "step=255 reward=np.float64(0.946125141646605) terminated=False\n", "step=256 reward=np.float64(0.9368755165399575) terminated=False\n", "step=257 reward=np.float64(0.9360986686274937) terminated=False\n", 
"step=258 reward=np.float64(0.9360986686274937) terminated=False\n", "step=259 reward=np.float64(0.9360986686274937) terminated=False\n", "step=260 reward=np.float64(0.9360986686274937) terminated=False\n", "step=261 reward=np.float64(0.9360986686274937) terminated=False\n", "step=262 reward=np.float64(0.9360986686274937) terminated=False\n", "step=263 reward=np.float64(0.9360986686274937) terminated=False\n", "step=264 reward=np.float64(0.9360986686274937) terminated=False\n", "step=265 reward=np.float64(0.9360986686274937) terminated=False\n", "step=266 reward=np.float64(0.9360986686274937) terminated=False\n", "step=267 reward=np.float64(0.9360986686274937) terminated=False\n", "step=268 reward=np.float64(0.9360986686274937) terminated=False\n", "step=269 reward=np.float64(0.9360986686274937) terminated=False\n", "step=270 reward=np.float64(0.9360986686274937) terminated=False\n", "step=271 reward=np.float64(0.9360986686274937) terminated=False\n", "step=272 reward=np.float64(0.9360986686274937) terminated=False\n", "step=273 reward=np.float64(0.9360986686274937) terminated=False\n", "step=274 reward=np.float64(0.9360986686274937) terminated=False\n", "step=275 reward=np.float64(0.9360986686274937) terminated=False\n", "step=276 reward=np.float64(0.9360986686274937) terminated=False\n", "step=277 reward=np.float64(0.9360986686274937) terminated=False\n", "step=278 reward=np.float64(0.9360986686274937) terminated=False\n", "step=279 reward=np.float64(0.9360986686274937) terminated=False\n", "step=280 reward=np.float64(0.9360986686274937) terminated=False\n", "step=281 reward=np.float64(0.9360986686274937) terminated=False\n", "step=282 reward=np.float64(0.9360986686274937) terminated=False\n", "step=283 reward=np.float64(0.9360986686274937) terminated=False\n", "step=284 reward=np.float64(0.9360986686274937) terminated=False\n", "step=285 reward=np.float64(0.9360986686274937) terminated=False\n", "step=286 reward=np.float64(0.9360986686274937) 
terminated=False\n",
      "step=287 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=288 reward=np.float64(0.9316550466986755) terminated=False\n",
      "step=289 reward=np.float64(0.9218676877631473) terminated=False\n",
      "step=290 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=291 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=292 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=293 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=294 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=295 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=296 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=297 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=298 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=299 reward=np.float64(0.9213441513220694) terminated=False\n"
     ]
    }
   ],
   "source": [
    "step = 0\n",
    "done = False\n",
    "\n",
    "while not done:\n",
    "    # Prepare observation for the policy\n",
    "    state = torch.from_numpy(numpy_observation[\"agent_pos\"])  # Agent position\n",
    "    env_state = torch.from_numpy(numpy_observation[\"environment_state\"])  # Environment state\n",
    "\n",
    "    # Convert to float32 (the environment returns float64 arrays)\n",
    "    state = state.to(torch.float32)\n",
    "    env_state = env_state.to(torch.float32)\n",
    "\n",
    "    # Move data tensors to the selected device (CPU or GPU)\n",
    "    state = state.to(device, non_blocking=True)\n",
    "    env_state = env_state.to(device, non_blocking=True)\n",
    "\n",
    "    # Add a leading batch dimension of size 1, required to forward the policy\n",
    "    state = state.unsqueeze(0)\n",
    "    env_state = env_state.unsqueeze(0)\n",
    "\n",
    "    # Create the policy input dictionary, keyed by the policy's input feature names\n",
    "    observation = {\n",
    "        \"observation.state\": state,\n",
    "        \"observation.environment_state\": env_state,\n",
    "    }\n",
    "\n",
    "    # Predict the next action with respect to the current observation\n",
    "    with torch.inference_mode():\n",
    "        action = policy.select_action(observation)\n",
    "\n",
    "    # Prepare the action for the environment: drop the batch dimension, back to numpy on CPU\n",
    "    numpy_action = action.squeeze(0).to(\"cpu\").numpy()\n",
    "\n",
    "    # Step through the environment and receive a new observation\n",
    "    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)\n",
    "    print(f\"{step=} {reward=} {terminated=}\")\n",
    "\n",
    "    # Keep track of all the rewards and frames\n",
    "    rewards.append(reward)\n",
    "    frames.append(env.render())\n",
    "\n",
    "    # The rollout is considered done when the success state is reached (i.e. terminated is True),\n",
    "    # or the maximum number of iterations is reached (i.e. truncated is True)\n",
    "    done = terminated or truncated or done\n",
    "    step += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Failure!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (680, 680) to (688, 688) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Video of the evaluation is available in '/home/lerobot/output/rollout.mp4'.\n"
     ]
    }
   ],
   "source": [
    "if terminated:\n",
    "    print(\"Success!\")\n",
    "else:\n",
    "    print(\"Failure!\")\n",
    "\n",
    "# Get the speed of environment (i.e. its number of frames per second).\n",
    "fps = env.metadata[\"render_fps\"]\n",
    "\n",
    "# Encode all frames into a mp4 video.\n",
    "# Create the output directory first: writing the video fails if it does not exist.\n",
    "video_path = Path(\"/home/lerobot/output/rollout.mp4\")\n",
    "video_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)\n",
    "\n",
    "print(f\"Video of the evaluation is available in '{video_path}'.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now on ALOHA (bimanual insertion task)\n",
    "import os\n",
    "\n",
    "# Resetting the ALOHA env renders the top camera through MuJoCo, which failed here with\n",
    "# `gladLoadGL error` (typical on headless machines). Select a headless GL backend unless\n",
    "# MUJOCO_GL is already set. This must happen before mujoco / dm_control are first imported\n",
    "# in this kernel (restart the kernel otherwise); use \"osmesa\" if EGL is not available.\n",
    "os.environ.setdefault(\"MUJOCO_GL\", \"egl\")\n",
    "\n",
    "import gym_aloha  # noqa: F401\n",
    "import numpy as np\n",
    "\n",
    "# The policy expects both the top camera image (`observation.images.top`) and the\n",
    "# 14-dim robot state (`observation.state`), so request pixels *and* agent position;\n",
    "# obs_type=\"pixels\" only provides the image.\n",
    "env = gym.make(\n",
    "    \"gym_aloha/AlohaInsertion-v0\",\n",
    "    obs_type=\"pixels_agent_pos\",\n",
    "    max_episode_steps=300,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/lerobot/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n",
      "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
      "100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]\n"
     ]
    }
   ],
   "source": [
    "from lerobot.common.policies.dot.modeling_dot import DOTPolicy\n",
    "\n",
    "# Pretrained DOT policy for the bimanual insertion task\n",
    "pretrained_policy_path = \"IliaLarchenko/dot_bimanual_insert\"\n",
    "policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can verify that the shapes of the features expected by the policy match the ones from the observations\n",
    "# produced by the environment\n",
    "print(policy.config.input_features)\n",
    "print(env.observation_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'action': PolicyFeature(type=, shape=(14,))}\n",
      "Box(-1.0, 1.0, (14,), float32)\n"
     ]
    }
   ],
   "source": [
    "# Similarly, we can check that the actions produced by the policy will match the actions expected by the\n",
    "# environment\n",
    "print(policy.config.output_features)\n",
    "print(env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "ename": "FatalError",
     "evalue": "gladLoadGL error",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFatalError\u001b[0m Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Reset the policy and environments to prepare for rollout\u001b[39;00m\n\u001b[1;32m 2\u001b[0m policy\u001b[38;5;241m.\u001b[39mreset()\n\u001b[0;32m----> 3\u001b[0m numpy_observation, info \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/time_limit.py:75\u001b[0m, in \u001b[0;36mTimeLimit.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.\u001b[39;00m\n\u001b[1;32m 67\u001b[0m \n\u001b[1;32m 68\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m The reset environment\u001b[39;00m\n\u001b[1;32m 73\u001b[0m 
\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_elapsed_steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/order_enforcing.py:61\u001b[0m, in \u001b[0;36mOrderEnforcing.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with `kwargs`.\"\"\"\u001b[39;00m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_has_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/env_checker.py:57\u001b[0m, in \u001b[0;36mPassiveEnvChecker.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43menv_reset_passive_checker\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39mreset(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:186\u001b[0m, in \u001b[0;36menv_reset_passive_checker\u001b[0;34m(env, **kwargs)\u001b[0m\n\u001b[1;32m 181\u001b[0m logger\u001b[38;5;241m.\u001b[39mdeprecation(\n\u001b[1;32m 182\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCurrent gymnasium version requires that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 183\u001b[0m )\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# Checks the result of env.reset with kwargs\u001b[39;00m\n\u001b[0;32m--> 186\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(result, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 189\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 190\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and 
`info` is a dictionary containing additional information. Actual type: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(result)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m )\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/env.py:166\u001b[0m, in \u001b[0;36mAlohaEnv.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtask)\n\u001b[0;32m--> 166\u001b[0m raw_obs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_raw_obs(raw_obs\u001b[38;5;241m.\u001b[39mobservation)\n\u001b[1;32m 170\u001b[0m info \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis_success\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mFalse\u001b[39;00m}\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/rl/control.py:89\u001b[0m, in \u001b[0;36mEnvironment.reset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mreset_context():\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_task\u001b[38;5;241m.\u001b[39minitialize_episode(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics)\n\u001b[0;32m---> 89\u001b[0m observation \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_task\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flat_observation:\n\u001b[1;32m 91\u001b[0m observation \u001b[38;5;241m=\u001b[39m flatten_observation(observation)\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/tasks/sim.py:92\u001b[0m, in \u001b[0;36mBimanualViperXTask.get_observation\u001b[0;34m(self, physics)\u001b[0m\n\u001b[1;32m 90\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menv_state\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_env_state(physics)\n\u001b[1;32m 91\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m---> 92\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtop\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mphysics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrender\u001b[49m\u001b[43m(\u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m480\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m640\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtop\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m 
obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 94\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvis\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfront_close\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:223\u001b[0m, in \u001b[0;36mPhysics.render\u001b[0;34m(self, height, width, camera_id, overlays, depth, segmentation, scene_option, render_flag_overrides, scene_callback)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrender\u001b[39m(\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 180\u001b[0m height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m240\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01mNone\u001b[39;00m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 190\u001b[0m ):\n\u001b[1;32m 191\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns a camera view as a NumPy array of pixel values.\u001b[39;00m\n\u001b[1;32m 192\u001b[0m \n\u001b[1;32m 193\u001b[0m 
\u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;124;03m The rendered RGB, depth or segmentation image.\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 223\u001b[0m camera \u001b[38;5;241m=\u001b[39m \u001b[43mCamera\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mphysics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwidth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcamera_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mscene_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscene_callback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m image \u001b[38;5;241m=\u001b[39m camera\u001b[38;5;241m.\u001b[39mrender(\n\u001b[1;32m 230\u001b[0m overlays\u001b[38;5;241m=\u001b[39moverlays, depth\u001b[38;5;241m=\u001b[39mdepth, segmentation\u001b[38;5;241m=\u001b[39msegmentation,\n\u001b[1;32m 231\u001b[0m scene_option\u001b[38;5;241m=\u001b[39mscene_option, render_flag_overrides\u001b[38;5;241m=\u001b[39mrender_flag_overrides)\n\u001b[1;32m 232\u001b[0m camera\u001b[38;5;241m.\u001b[39m_scene\u001b[38;5;241m.\u001b[39mfree() \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:711\u001b[0m, in \u001b[0;36mCamera.__init__\u001b[0;34m(self, physics, height, width, camera_id, max_geom, scene_callback)\u001b[0m\n\u001b[1;32m 
708\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rgb_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width, \u001b[38;5;241m3\u001b[39m), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39muint8)\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_depth_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m--> 711\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontexts\u001b[49m\u001b[38;5;241m.\u001b[39mmujoco \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 712\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mgl\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[1;32m 713\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN,\n\u001b[1;32m 714\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mmujoco\u001b[38;5;241m.\u001b[39mptr)\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:533\u001b[0m, in \u001b[0;36mPhysics.contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts_lock:\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts:\n\u001b[0;32m--> 533\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_rendering_contexts\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 534\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:519\u001b[0m, in \u001b[0;36mPhysics._make_rendering_contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 516\u001b[0m render_context \u001b[38;5;241m=\u001b[39m _render\u001b[38;5;241m.\u001b[39mRenderer(\n\u001b[1;32m 517\u001b[0m max_width\u001b[38;5;241m=\u001b[39mmax_width, max_height\u001b[38;5;241m=\u001b[39mmax_height)\n\u001b[1;32m 518\u001b[0m \u001b[38;5;66;03m# Create the MuJoCo context.\u001b[39;00m\n\u001b[0;32m--> 519\u001b[0m mujoco_context \u001b[38;5;241m=\u001b[39m \u001b[43mwrapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrender_context\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 520\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts \u001b[38;5;241m=\u001b[39m Contexts(gl\u001b[38;5;241m=\u001b[39mrender_context, mujoco\u001b[38;5;241m=\u001b[39mmujoco_context)\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/wrapper/core.py:603\u001b[0m, in \u001b[0;36mMjrContext.__init__\u001b[0;34m(self, model, gl_context, font_scale)\u001b[0m\n\u001b[1;32m 601\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gl_context 
\u001b[38;5;241m=\u001b[39m gl_context\n\u001b[1;32m 602\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m gl_context\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[0;32m--> 603\u001b[0m ptr \u001b[38;5;241m=\u001b[39m \u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmujoco\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mptr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfont_scale\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 604\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN, ptr)\n\u001b[1;32m 605\u001b[0m gl_context\u001b[38;5;241m.\u001b[39mkeep_alive(ptr)\n", "File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/_render/executor/render_executor.py:138\u001b[0m, in \u001b[0;36mPassthroughRenderExecutor.call\u001b[0;34m(self, func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcall\u001b[39m(\u001b[38;5;28mself\u001b[39m, func, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_locked()\n\u001b[0;32m--> 138\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mFatalError\u001b[0m: gladLoadGL error" ] } ], "source": [ "# Reset the policy and environments to prepare for rollout\n", 
"policy.reset()\n", "numpy_observation, info = env.reset(seed=42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "lerobot", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }