{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import gym_pusht  # noqa: F401\n",
    "import gymnasium as gym\n",
    "import imageio\n",
    "import numpy\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select your device\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Provide the [Hugging Face repo id](https://huggingface.co/IliaLarchenko/dot_pusht_keypoints):\n",
    "pretrained_policy_path = \"IliaLarchenko/dot_pusht_keypoints\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lerobot.common.policies.dot.modeling_dot import DOTPolicy\n",
    "\n",
    "policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)"
   ]
  },
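  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`from_pretrained` with `map_location` already loads the weights onto the chosen device. As an optional extra step (a sketch added here, not part of the original rollout), the module can also be moved explicitly and switched to eval mode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: DOTPolicy is a torch.nn.Module, so .to() and .eval() apply.\n",
    "policy.to(device)\n",
    "policy.eval();"
   ]
  },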
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The keypoints policy takes low-dimensional observations, so request the\n",
    "# environment-state + agent-position observation type rather than pixels.\n",
    "env = gym.make(\n",
    "    \"gym_pusht/PushT-v0\",\n",
    "    obs_type=\"environment_state_agent_pos\",\n",
    "    max_episode_steps=300,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,)), 'observation.environment_state': PolicyFeature(type=<FeatureType.ENV: 'ENV'>, shape=(16,))}\n",
      "Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'environment_state': Box(0.0, 512.0, (16,), float64))\n"
     ]
    }
   ],
   "source": [
    "# Verify that the shapes of the features expected by the policy match the\n",
    "# observations produced by the environment\n",
    "print(policy.config.input_features)\n",
    "print(env.observation_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(2,))}\n",
      "Box(0.0, 512.0, (2,), float32)\n"
     ]
    }
   ],
   "source": [
    "# Similarly, verify that the actions produced by the policy match the actions\n",
    "# expected by the environment\n",
    "print(policy.config.output_features)\n",
    "print(env.action_space)"
   ]
  },
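  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The feature/space match printed above can also be checked programmatically. A minimal sketch, assuming the `PolicyFeature.shape` tuples compare directly with the Gymnasium space shapes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: assert that the env spaces and the policy features agree on shapes.\n",
    "assert env.observation_space[\"agent_pos\"].shape == policy.config.input_features[\"observation.state\"].shape\n",
    "assert (\n",
    "    env.observation_space[\"environment_state\"].shape\n",
    "    == policy.config.input_features[\"observation.environment_state\"].shape\n",
    ")\n",
    "assert env.action_space.shape == policy.config.output_features[\"action\"].shape"
   ]
  },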
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the policy and environment to prepare for rollout\n",
    "policy.reset()\n",
    "numpy_observation, info = env.reset(seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare to collect all the rewards and all the frames of the episode,\n",
    "# from the initial state to the final state.\n",
    "rewards = []\n",
    "frames = []\n",
    "\n",
    "# Render the frame of the initial state\n",
    "frames.append(env.render())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step=0 reward=np.float64(0.0) terminated=False\n",
      "step=1 reward=np.float64(0.0) terminated=False\n",
      "step=2 reward=np.float64(0.0) terminated=False\n",
      "step=3 reward=np.float64(0.0) terminated=False\n",
      "step=4 reward=np.float64(0.0) terminated=False\n",
      "step=5 reward=np.float64(0.0) terminated=False\n",
      "step=6 reward=np.float64(0.0) terminated=False\n",
      "step=7 reward=np.float64(0.0) terminated=False\n",
      "step=8 reward=np.float64(0.0) terminated=False\n",
      "step=9 reward=np.float64(0.0) terminated=False\n",
      "step=10 reward=np.float64(0.0) terminated=False\n",
      "step=11 reward=np.float64(0.0) terminated=False\n",
      "step=12 reward=np.float64(0.0) terminated=False\n",
      "step=13 reward=np.float64(0.0) terminated=False\n",
      "step=14 reward=np.float64(0.0) terminated=False\n",
      "step=15 reward=np.float64(0.0) terminated=False\n",
      "step=16 reward=np.float64(0.0) terminated=False\n",
      "step=17 reward=np.float64(0.0) terminated=False\n",
      "step=18 reward=np.float64(0.0) terminated=False\n",
      "step=19 reward=np.float64(0.0) terminated=False\n",
      "step=20 reward=np.float64(0.0) terminated=False\n",
      "step=21 reward=np.float64(0.0) terminated=False\n",
      "step=22 reward=np.float64(0.0009941544780861455) terminated=False\n",
      "step=23 reward=np.float64(0.033647507519038757) terminated=False\n",
      "step=24 reward=np.float64(0.07026086006261555) terminated=False\n",
      "step=25 reward=np.float64(0.10069667553409196) terminated=False\n",
      "step=26 reward=np.float64(0.11389926069925992) terminated=False\n",
      "step=27 reward=np.float64(0.12027077768723497) terminated=False\n",
      "step=28 reward=np.float64(0.12486582623684722) terminated=False\n",
      "step=29 reward=np.float64(0.12815916861048604) terminated=False\n",
      "step=30 reward=np.float64(0.1303391815805222) terminated=False\n",
      "step=31 reward=np.float64(0.1315231117258188) terminated=False\n",
      "step=32 reward=np.float64(0.13221640549835664) terminated=False\n",
      "step=33 reward=np.float64(0.13254763259209015) terminated=False\n",
      "step=34 reward=np.float64(0.13263558368425837) terminated=False\n",
      "step=35 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=36 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=37 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=38 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=39 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=40 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=41 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=42 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=43 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=44 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=45 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=46 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=47 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=48 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=49 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=50 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=51 reward=np.float64(0.13263932937572084) terminated=False\n",
      "step=52 reward=np.float64(0.14872285364307145) terminated=False\n",
      "step=53 reward=np.float64(0.19847005044261715) terminated=False\n",
      "step=54 reward=np.float64(0.24338272205852812) terminated=False\n",
      "step=55 reward=np.float64(0.2667243347061481) terminated=False\n",
      "step=56 reward=np.float64(0.2691675276421592) terminated=False\n",
      "step=57 reward=np.float64(0.3018254762158707) terminated=False\n",
      "step=58 reward=np.float64(0.3613501686331564) terminated=False\n",
      "step=59 reward=np.float64(0.4613512243665896) terminated=False\n",
      "step=60 reward=np.float64(0.5617656688929643) terminated=False\n",
      "step=61 reward=np.float64(0.5747609180351871) terminated=False\n",
      "step=62 reward=np.float64(0.507048651118485) terminated=False\n",
      "step=63 reward=np.float64(0.44332042287270484) terminated=False\n",
      "step=64 reward=np.float64(0.3993804222553378) terminated=False\n",
      "step=65 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=66 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=67 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=68 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=69 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=70 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=71 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=72 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=73 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=74 reward=np.float64(0.36941592278664487) terminated=False\n",
      "step=75 reward=np.float64(0.4322328474940646) terminated=False\n",
      "step=76 reward=np.float64(0.4818152566968738) terminated=False\n",
      "step=77 reward=np.float64(0.5252535051167734) terminated=False\n",
      "step=78 reward=np.float64(0.5586446249197407) terminated=False\n",
      "step=79 reward=np.float64(0.5885022076599307) terminated=False\n",
      "step=80 reward=np.float64(0.5977994643852952) terminated=False\n",
      "step=81 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=82 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=83 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=84 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=85 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=86 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=87 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=88 reward=np.float64(0.597859201570885) terminated=False\n",
      "step=89 reward=np.float64(0.6876341127908622) terminated=False\n",
      "step=90 reward=np.float64(0.8166289152424572) terminated=False\n",
      "step=91 reward=np.float64(0.9421614978354362) terminated=False\n",
      "step=92 reward=np.float64(0.9441608976568224) terminated=False\n",
      "step=93 reward=np.float64(0.9104163604296966) terminated=False\n",
      "step=94 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=95 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=96 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=97 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=98 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=99 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=100 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=101 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=102 reward=np.float64(0.909182661371769) terminated=False\n",
      "step=103 reward=np.float64(0.9340357871805705) terminated=False\n",
      "step=104 reward=np.float64(0.8851102121651142) terminated=False\n",
      "step=105 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=106 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=107 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=108 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=109 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=110 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=111 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=112 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=113 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=114 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=115 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=116 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=117 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=118 reward=np.float64(0.8809768749693764) terminated=False\n",
      "step=119 reward=np.float64(0.9518089158714292) terminated=False\n",
      "step=120 reward=np.float64(0.9405458729516311) terminated=False\n",
      "step=121 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=122 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=123 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=124 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=125 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=126 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=127 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=128 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=129 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=130 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=131 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=132 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=133 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=134 reward=np.float64(0.935511214687435) terminated=False\n",
      "step=135 reward=np.float64(0.9534990217822209) terminated=False\n",
      "step=136 reward=np.float64(0.9596585109597399) terminated=False\n",
      "step=137 reward=np.float64(0.882875733420291) terminated=False\n",
      "step=138 reward=np.float64(0.8277880190838034) terminated=False\n",
      "step=139 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=140 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=141 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=142 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=143 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=144 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=145 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=146 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=147 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=148 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=149 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=150 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=151 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=152 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=153 reward=np.float64(0.8211529871155266) terminated=False\n",
      "step=154 reward=np.float64(0.856408395120982) terminated=False\n",
      "step=155 reward=np.float64(0.9304040416833055) terminated=False\n",
      "step=156 reward=np.float64(0.9770812279113622) terminated=False\n",
      "step=157 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=158 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=159 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=160 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=161 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=162 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=163 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=164 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=165 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=166 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=167 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=168 reward=np.float64(0.9944597232685968) terminated=False\n",
      "step=169 reward=np.float64(0.9820541217675358) terminated=False\n",
      "step=170 reward=np.float64(0.8765528119646949) terminated=False\n",
      "step=171 reward=np.float64(0.8231919366320396) terminated=False\n",
      "step=172 reward=np.float64(0.7926155231821123) terminated=False\n",
      "step=173 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=174 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=175 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=176 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=177 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=178 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=179 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=180 reward=np.float64(0.7902960563492054) terminated=False\n",
      "step=181 reward=np.float64(0.8158199658870418) terminated=False\n",
      "step=182 reward=np.float64(0.8191627090126786) terminated=False\n",
      "step=183 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=184 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=185 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=186 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=187 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=188 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=189 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=190 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=191 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=192 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=193 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=194 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=195 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=196 reward=np.float64(0.8172224839948001) terminated=False\n",
      "step=197 reward=np.float64(0.878735078350138) terminated=False\n",
      "step=198 reward=np.float64(0.8564816396314117) terminated=False\n",
      "step=199 reward=np.float64(0.7970005244772627) terminated=False\n",
      "step=200 reward=np.float64(0.7688729860960439) terminated=False\n",
      "step=201 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=202 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=203 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=204 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=205 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=206 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=207 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=208 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=209 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=210 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=211 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=212 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=213 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=214 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=215 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=216 reward=np.float64(0.7671148163486476) terminated=False\n",
      "step=217 reward=np.float64(0.8045352949993082) terminated=False\n",
      "step=218 reward=np.float64(0.8328184612705187) terminated=False\n",
      "step=219 reward=np.float64(0.8558996801195216) terminated=False\n",
      "step=220 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=221 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=222 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=223 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=224 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=225 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=226 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=227 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=228 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=229 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=230 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=231 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=232 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=233 reward=np.float64(0.8576887923798919) terminated=False\n",
      "step=234 reward=np.float64(0.935454295639811) terminated=False\n",
      "step=235 reward=np.float64(0.9874094853870982) terminated=False\n",
      "step=236 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=237 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=238 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=239 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=240 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=241 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=242 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=243 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=244 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=245 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=246 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=247 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=248 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=249 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=250 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=251 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=252 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=253 reward=np.float64(0.9719847917595543) terminated=False\n",
      "step=254 reward=np.float64(0.9727790631955697) terminated=False\n",
      "step=255 reward=np.float64(0.946125141646605) terminated=False\n",
      "step=256 reward=np.float64(0.9368755165399575) terminated=False\n",
      "step=257 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=258 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=259 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=260 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=261 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=262 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=263 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=264 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=265 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=266 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=267 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=268 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=269 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=270 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=271 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=272 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=273 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=274 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=275 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=276 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=277 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=278 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=279 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=280 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=281 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=282 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=283 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=284 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=285 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=286 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=287 reward=np.float64(0.9360986686274937) terminated=False\n",
      "step=288 reward=np.float64(0.9316550466986755) terminated=False\n",
      "step=289 reward=np.float64(0.9218676877631473) terminated=False\n",
      "step=290 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=291 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=292 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=293 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=294 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=295 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=296 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=297 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=298 reward=np.float64(0.9213441513220694) terminated=False\n",
      "step=299 reward=np.float64(0.9213441513220694) terminated=False\n"
     ]
    }
   ],
   "source": [
    "step = 0\n",
    "done = False\n",
    "\n",
    "while not done:\n",
    "    # Prepare observation for the policy\n",
    "    state = torch.from_numpy(numpy_observation[\"agent_pos\"])  # Agent position\n",
    "    env_state = torch.from_numpy(numpy_observation[\"environment_state\"])  # Environment state\n",
    "\n",
    "    # Convert to float32\n",
    "    state = state.to(torch.float32)\n",
    "    env_state = env_state.to(torch.float32)\n",
    "\n",
    "    # Send data tensors from CPU to GPU\n",
    "    state = state.to(device, non_blocking=True)\n",
    "    env_state = env_state.to(device, non_blocking=True)\n",
    "\n",
    "    # Add an extra (empty) batch dimension, required to forward the policy\n",
    "    state = state.unsqueeze(0)\n",
    "    env_state = env_state.unsqueeze(0)\n",
    "\n",
    "    # Create the policy input dictionary\n",
    "    observation = {\n",
    "        \"observation.state\": state,\n",
    "        \"observation.environment_state\": env_state,\n",
    "    }\n",
    "\n",
    "    # Predict the next action with respect to the current observation\n",
    "    with torch.inference_mode():\n",
    "        action = policy.select_action(observation)\n",
    "\n",
    "    # Prepare the action for the environment\n",
    "    numpy_action = action.squeeze(0).to(\"cpu\").numpy()\n",
    "\n",
    "    # Step through the environment and receive a new observation\n",
    "    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)\n",
    "    print(f\"{step=} {reward=} {terminated=}\")\n",
    "\n",
    "    # Keep track of all the rewards and frames\n",
    "    rewards.append(reward)\n",
    "    frames.append(env.render())\n",
    "\n",
    "    # The rollout is considered done when the success state is reached (i.e. terminated is True),\n",
    "    # or the maximum number of iterations is reached (i.e. truncated is True)\n",
    "    done = terminated or truncated\n",
    "    step += 1\n"
   ]
  },
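  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before encoding the video, a quick episode summary can be computed from the `rewards` list collected above (an added convenience cell, not part of the original rollout):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarize the episode from the collected rewards.\n",
    "print(f\"Episode length: {len(rewards)} steps\")\n",
    "print(f\"Total reward: {sum(rewards):.3f}\")\n",
    "print(f\"Max reward: {max(rewards):.3f}\")"
   ]
  },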
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Failure!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (680, 680) to (688, 688) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Video of the evaluation is available in '/home/lerobot/output/rollout.mp4'.\n"
     ]
    }
   ],
   "source": [
    "if terminated:\n",
    "    print(\"Success!\")\n",
    "else:\n",
    "    print(\"Failure!\")\n",
    "\n",
    "# Get the speed of the environment (i.e. its number of frames per second).\n",
    "fps = env.metadata[\"render_fps\"]\n",
    "\n",
    "# Encode all frames into a mp4 video, creating the output directory if needed.\n",
    "output_directory = Path(\"/home/lerobot/output\")\n",
    "output_directory.mkdir(parents=True, exist_ok=True)\n",
    "video_path = output_directory / \"rollout.mp4\"\n",
    "imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)\n",
    "\n",
    "print(f\"Video of the evaluation is available in '{video_path}'.\")"
   ]
  },
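  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To watch the rollout without leaving the notebook, the saved mp4 can be embedded inline (an optional addition, assuming an IPython kernel with video support):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video\n",
    "\n",
    "# Embed the saved rollout video in the notebook output.\n",
    "Video(str(video_path), embed=True)"
   ]
  },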
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now evaluate on an ALOHA environment (bimanual insertion task).\n",
    "# The imports from the first cell are reused; only gym_aloha is new.\n",
    "import gym_aloha  # noqa: F401\n",
    "\n",
    "env = gym.make(\n",
    "    \"gym_aloha/AlohaInsertion-v0\",\n",
    "    obs_type=\"pixels\",\n",
    "    max_episode_steps=300,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/lerobot/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
      "100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]\n"
     ]
    }
   ],
   "source": [
    "from lerobot.common.policies.dot.modeling_dot import DOTPolicy\n",
    "\n",
    "pretrained_policy_path = \"IliaLarchenko/dot_bimanual_insert\"\n",
    "policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'observation.images.top': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(14,))}\n",
      "Dict('top': Box(0, 255, (480, 640, 3), uint8))\n"
     ]
    }
   ],
   "source": [
    "# We can verify that the shapes of the features expected by the policy match the ones\n",
    "# from the observations produced by the environment.\n",
    "# Note: the policy also expects 'observation.state' with shape (14,), which the\n",
    "# \"pixels\" obs_type does not provide; gym_aloha's \"pixels_agent_pos\" obs_type\n",
    "# likely matches the policy's input features better.\n",
    "print(policy.config.input_features)\n",
    "print(env.observation_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(14,))}\n",
      "Box(-1.0, 1.0, (14,), float32)\n"
     ]
    }
   ],
   "source": [
    "# Similarly, we can check that the actions produced by the policy will match the\n",
    "# actions expected by the environment\n",
    "print(policy.config.output_features)\n",
    "print(env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "ename": "FatalError",
     "evalue": "gladLoadGL error",
     "output_type": "error",
     "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFatalError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[16], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Reset the policy and environments to prepare for rollout\u001b[39;00m\n\u001b[1;32m 2\u001b[0m policy\u001b[38;5;241m.\u001b[39mreset()\n\u001b[0;32m----> 3\u001b[0m numpy_observation, info \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/time_limit.py:75\u001b[0m, in \u001b[0;36mTimeLimit.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.\u001b[39;00m\n\u001b[1;32m 67\u001b[0m \n\u001b[1;32m 68\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m The reset environment\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_elapsed_steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/order_enforcing.py:61\u001b[0m, in \u001b[0;36mOrderEnforcing.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Resets the environment with `kwargs`.\"\"\"\u001b[39;00m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_has_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/wrappers/env_checker.py:57\u001b[0m, in \u001b[0;36mPassiveEnvChecker.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchecked_reset \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43menv_reset_passive_checker\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv\u001b[38;5;241m.\u001b[39mreset(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:186\u001b[0m, in \u001b[0;36menv_reset_passive_checker\u001b[0;34m(env, **kwargs)\u001b[0m\n\u001b[1;32m 181\u001b[0m logger\u001b[38;5;241m.\u001b[39mdeprecation(\n\u001b[1;32m 182\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCurrent gymnasium version requires that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 183\u001b[0m )\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# Checks the result of env.reset with kwargs\u001b[39;00m\n\u001b[0;32m--> 186\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(result, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 189\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 190\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(result)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m )\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/env.py:166\u001b[0m, in \u001b[0;36mAlohaEnv.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtask)\n\u001b[0;32m--> 166\u001b[0m raw_obs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_raw_obs(raw_obs\u001b[38;5;241m.\u001b[39mobservation)\n\u001b[1;32m 170\u001b[0m info \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis_success\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mFalse\u001b[39;00m}\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/rl/control.py:89\u001b[0m, in \u001b[0;36mEnvironment.reset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mreset_context():\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_task\u001b[38;5;241m.\u001b[39minitialize_episode(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics)\n\u001b[0;32m---> 89\u001b[0m observation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_task\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_observation\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_flat_observation:\n\u001b[1;32m 91\u001b[0m observation \u001b[38;5;241m=\u001b[39m flatten_observation(observation)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/gym_aloha/tasks/sim.py:92\u001b[0m, in \u001b[0;36mBimanualViperXTask.get_observation\u001b[0;34m(self, physics)\u001b[0m\n\u001b[1;32m 90\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menv_state\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_env_state(physics)\n\u001b[1;32m 91\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m {}\n\u001b[0;32m---> 92\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtop\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mphysics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrender\u001b[49m\u001b[43m(\u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m480\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m640\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtop\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mangle\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 94\u001b[0m obs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimages\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvis\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m physics\u001b[38;5;241m.\u001b[39mrender(height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m480\u001b[39m, width\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m640\u001b[39m, camera_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfront_close\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:223\u001b[0m, in \u001b[0;36mPhysics.render\u001b[0;34m(self, height, width, camera_id, overlays, depth, segmentation, scene_option, render_flag_overrides, scene_callback)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrender\u001b[39m(\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 180\u001b[0m height\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m240\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01mNone\u001b[39;00m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 190\u001b[0m ):\n\u001b[1;32m 191\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns a camera view as a NumPy array of pixel values.\u001b[39;00m\n\u001b[1;32m 192\u001b[0m \n\u001b[1;32m 193\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;124;03m The rendered RGB, depth or segmentation image.\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 223\u001b[0m camera \u001b[38;5;241m=\u001b[39m \u001b[43mCamera\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mphysics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwidth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mcamera_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcamera_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mscene_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscene_callback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m image \u001b[38;5;241m=\u001b[39m camera\u001b[38;5;241m.\u001b[39mrender(\n\u001b[1;32m 230\u001b[0m overlays\u001b[38;5;241m=\u001b[39moverlays, depth\u001b[38;5;241m=\u001b[39mdepth, segmentation\u001b[38;5;241m=\u001b[39msegmentation,\n\u001b[1;32m 231\u001b[0m scene_option\u001b[38;5;241m=\u001b[39mscene_option, render_flag_overrides\u001b[38;5;241m=\u001b[39mrender_flag_overrides)\n\u001b[1;32m 232\u001b[0m camera\u001b[38;5;241m.\u001b[39m_scene\u001b[38;5;241m.\u001b[39mfree() \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:711\u001b[0m, in \u001b[0;36mCamera.__init__\u001b[0;34m(self, physics, height, width, camera_id, max_geom, scene_callback)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rgb_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width, \u001b[38;5;241m3\u001b[39m), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39muint8)\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_depth_buffer \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mempty((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_height, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_width), dtype\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m--> 711\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_physics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontexts\u001b[49m\u001b[38;5;241m.\u001b[39mmujoco \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 712\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mgl\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[1;32m 713\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN,\n\u001b[1;32m 714\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_physics\u001b[38;5;241m.\u001b[39mcontexts\u001b[38;5;241m.\u001b[39mmujoco\u001b[38;5;241m.\u001b[39mptr)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:533\u001b[0m, in \u001b[0;36mPhysics.contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts_lock:\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts:\n\u001b[0;32m--> 533\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_rendering_contexts\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 534\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/engine.py:519\u001b[0m, in \u001b[0;36mPhysics._make_rendering_contexts\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 516\u001b[0m render_context \u001b[38;5;241m=\u001b[39m _render\u001b[38;5;241m.\u001b[39mRenderer(\n\u001b[1;32m 517\u001b[0m max_width\u001b[38;5;241m=\u001b[39mmax_width, max_height\u001b[38;5;241m=\u001b[39mmax_height)\n\u001b[1;32m 518\u001b[0m \u001b[38;5;66;03m# Create the MuJoCo context.\u001b[39;00m\n\u001b[0;32m--> 519\u001b[0m mujoco_context \u001b[38;5;241m=\u001b[39m \u001b[43mwrapper\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrender_context\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 520\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_contexts \u001b[38;5;241m=\u001b[39m Contexts(gl\u001b[38;5;241m=\u001b[39mrender_context, mujoco\u001b[38;5;241m=\u001b[39mmujoco_context)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/mujoco/wrapper/core.py:603\u001b[0m, in \u001b[0;36mMjrContext.__init__\u001b[0;34m(self, model, gl_context, font_scale)\u001b[0m\n\u001b[1;32m 601\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gl_context \u001b[38;5;241m=\u001b[39m gl_context\n\u001b[1;32m 602\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m gl_context\u001b[38;5;241m.\u001b[39mmake_current() \u001b[38;5;28;01mas\u001b[39;00m ctx:\n\u001b[0;32m--> 603\u001b[0m ptr \u001b[38;5;241m=\u001b[39m \u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmujoco\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMjrContext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mptr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfont_scale\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 604\u001b[0m ctx\u001b[38;5;241m.\u001b[39mcall(mujoco\u001b[38;5;241m.\u001b[39mmjr_setBuffer, mujoco\u001b[38;5;241m.\u001b[39mmjtFramebuffer\u001b[38;5;241m.\u001b[39mmjFB_OFFSCREEN, ptr)\n\u001b[1;32m 605\u001b[0m gl_context\u001b[38;5;241m.\u001b[39mkeep_alive(ptr)\n",
"File \u001b[0;32m/opt/conda/envs/lerobot/lib/python3.10/site-packages/dm_control/_render/executor/render_executor.py:138\u001b[0m, in \u001b[0;36mPassthroughRenderExecutor.call\u001b[0;34m(self, func, *args, **kwargs)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcall\u001b[39m(\u001b[38;5;28mself\u001b[39m, func, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_locked()\n\u001b[0;32m--> 138\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mFatalError\u001b[0m: gladLoadGL error"
     ]
    }
   ],
   "source": [
    "# Reset the policy and environments to prepare for rollout\n",
    "policy.reset()\n",
    "numpy_observation, info = env.reset(seed=42)"
   ]
  },
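  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `gladLoadGL error` above means MuJoCo could not create an OpenGL context, which typically happens on headless machines. A common workaround (an assumption about this setup, not a verified fix) is to select a headless rendering backend via the `MUJOCO_GL` environment variable before any rendering context is created, e.g. right after restarting the kernel and before running the gym_aloha cells:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Assumed workaround: pick a headless MuJoCo rendering backend.\n",
    "# This must run before dm_control creates its first OpenGL context,\n",
    "# so restart the kernel and execute it before the gym_aloha cells.\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\"  # or \"osmesa\" for software rendering"
   ]
  }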
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lerobot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}