Add Aloha env and ACT policy

WIP Aloha env tests pass

Rendering works (fps look fast tho? TODO action bounding is too wide [-1,1])

Update README

Copy past from act repo

Remove download.py add a WIP for Simxarm

Remove download.py add a WIP for Simxarm

Add act yaml (TODO: try train.py)

Training can runs (TODO: eval)

Add tasks without end_effector that are compatible with dataset, Eval can run (TODO: training and pretrained model)

Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications, (TODO: Refactor Pusht and Simxarm)

poetry lock

fix bug in compute_stats for action normalization

fix more bugs in normalization

fix training

fix import

PushtEnv inheriates AbstractEnv, Improve factory Normalization

Add _make_env to EnvAbstract

Add call_rendering_hooks to pusht env

SimxarmEnv inherites from AbstractEnv (NOT TESTED)

Add aloha tests artifacts + update pusht stats

fix image normalization: before env was in [0,1] but dataset in [0,255], and now both in [0,255]

Small fix on simxarm

Add next to obs

Add top camera to Aloha env (TODO: make it compatible with set of cameras)

Add top camera to Aloha env (TODO: make it compatible with set of cameras)
This commit is contained in:
Remi Cadene 2024-03-08 09:47:39 +00:00 committed by Cadene
parent 060bac7672
commit 9d002032d1
116 changed files with 3658 additions and 301 deletions

View File

@ -59,19 +59,10 @@ env=pusht
## TODO
- [x] priority update doesn't match FOWM or original paper
- [x] self.step=100000 should be updated at every step to adjust to the horizon of the planner
- [ ] prefetch replay buffer to speedup training
- [ ] parallelize env to speed up eval
- [ ] clean checkpointing / loading
- [ ] clean logging
- [ ] clean config
- [ ] clean hyperparameter tuning
- [ ] add pusht
- [ ] add aloha
- [ ] add act
- [ ] add diffusion
- [ ] add aloha 2
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
## Profile

View File

@ -81,7 +81,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def set_transform(self, transform):
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
self._transform = Compose(transform)
if isinstance(transform, list):
self._transform = Compose(*transform)
else:
self._transform = Compose(transform)
else:
self._transform = transform

View File

@ -73,11 +73,11 @@ def download(data_dir, dataset_id):
data_dir.mkdir(parents=True, exist_ok=True)
gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
class AlohaExperienceReplay(AbstractExperienceReplay):
@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
# def _is_downloaded(self) -> bool:
# return False
def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
if not raw_dir.is_dir():

View File

@ -5,7 +5,7 @@ from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.envs.transforms import NormalizeTransform
from lerobot.common.envs.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
@ -84,6 +84,16 @@ def make_offline_buffer(
prefetch=prefetch if isinstance(prefetch, int) else None,
)
if cfg.policy.name == "tdmpc":
img_keys = []
for key in offline_buffer.image_keys:
img_keys.append(("next", *key))
img_keys += offline_buffer.image_keys
else:
img_keys = offline_buffer.image_keys
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
@ -92,11 +102,10 @@ def make_offline_buffer(
in_keys = [("observation", "state"), ("action")]
if cfg.policy.name == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
# since we use next observations in tdmpc
in_keys.append(("next", *key))
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
in_keys += img_keys
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
in_keys += [("next", *key) for key in img_keys]
in_keys.append(("next", "observation", "state"))
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
@ -106,8 +115,11 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
offline_buffer.set_transform(transforms)
if not overwrite_sampler:
index = torch.arange(0, offline_buffer.num_samples, 1)

View File

@ -1,4 +1,5 @@
import pickle
import zipfile
from pathlib import Path
from typing import Callable
@ -15,6 +16,22 @@ from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
def download():
raise NotImplementedError()
import gdown
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
download_path = "data.zip"
gdown.download(url, download_path, quiet=False)
print("Extracting...")
with zipfile.ZipFile(download_path, "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
Path(download_path).unlink()
class SimxarmExperienceReplay(AbstractExperienceReplay):
available_datasets = [
"xarm_lift_medium",
@ -48,8 +65,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
)
def _download_and_preproc(self):
# download
# TODO(rcadene)
# TODO(rcadene): finish download
download()
dataset_path = self.data_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")

View File

@ -0,0 +1,80 @@
import abc
from collections import deque
from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
class AbstractEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def register_rendering_hook(self, func):
self._rendering_hooks.append(func)
def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)
def reset_rendering_hooks(self):
self._rendering_hooks = []
@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()
@abc.abstractmethod
def _reset(self, tensordict: Optional[TensorDict] = None):
raise NotImplementedError()
@abc.abstractmethod
def _step(self, tensordict: TensorDict):
raise NotImplementedError()
@abc.abstractmethod
def _make_env(self):
raise NotImplementedError()
@abc.abstractmethod
def _make_spec(self):
raise NotImplementedError()
@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError()

View File

@ -0,0 +1,59 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<equality>
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
</equality>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
</body>
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
</body>
<body name="peg" pos="0.2 0.5 0.05">
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
</body>
<body name="socket" pos="-0.2 0.5 0.05">
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@ -0,0 +1,48 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<equality>
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
</equality>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
</body>
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
</body>
<body name="box" pos="0.2 0.5 0.05">
<joint name="red_box_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@ -0,0 +1,53 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body name="peg" pos="0.2 0.5 0.05">
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
</body>
<body name="socket" pos="-0.2 0.5 0.05">
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@ -0,0 +1,42 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body name="box" pos="0.2 0.5 0.05">
<joint name="red_box_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@ -0,0 +1,38 @@
<mujocoinclude>
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
<asset>
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
</asset>
<visual>
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
<quality shadowsize="4096" offsamples="4"/>
<headlight ambient="0.4 0.4 0.4"/>
</visual>
<worldbody>
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
dir='1 1 -1'/>
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
dir='0 -1 -1'/>
<body name="table" pos="0 .6 0">
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
</body>
<body name="midair" pos="0 .6 0.2">
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
</body>
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
</worldbody>
</mujocoinclude>

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,17 @@
<mujocoinclude>
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
<asset>
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
</asset>
</mujocoinclude>

View File

@ -0,0 +1,59 @@
<mujocoinclude>
<body name="vx300s_left" pos="-0.469 0.5 0">
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
</body>
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
</body>
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
</body>
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</mujocoinclude>

View File

@ -0,0 +1,59 @@
<mujocoinclude>
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
</body>
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
</body>
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
</body>
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</mujocoinclude>

View File

@ -0,0 +1,163 @@
from pathlib import Path
### Simulation envs fixed constants
DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
FPS = 50
JOINTS = [
# absolute joint position
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"left_arm_gripper",
# absolute joint position
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"right_arm_gripper",
]
ACTIONS = [
# position and quaternion for end effector
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"left_arm_gripper",
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"right_arm_gripper",
]
START_ARM_POSE = [
0,
-0.96,
1.16,
0,
-0.3,
0,
0.02239,
-0.02239,
0,
-0.96,
1.16,
0,
-0.3,
0,
0.02239,
-0.02239,
]
ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
############################ Helper functions ############################
def normalize_master_gripper_position(x):
return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
def normalize_puppet_gripper_position(x):
return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
def unnormalize_master_gripper_position(x):
return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
def unnormalize_puppet_gripper_position(x):
return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
def convert_position_from_master_to_puppet(x):
return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
def normalizer_master_gripper_joint(x):
return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
def normalize_puppet_gripper_joint(x):
return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
def unnormalize_master_gripper_joint(x):
return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
def unnormalize_puppet_gripper_joint(x):
return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
def convert_join_from_master_to_puppet(x):
return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
def normalize_master_gripper_velocity(x):
return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
def normalize_puppet_gripper_velocity(x):
return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
def convert_master_from_position_to_joint(x):
return (
normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ MASTER_GRIPPER_JOINT_CLOSE
)
def convert_master_from_joint_to_position(x):
return unnormalize_master_gripper_position(
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
def convert_puppet_from_position_to_join(x):
return (
normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ PUPPET_GRIPPER_JOINT_CLOSE
)
def convert_puppet_from_joint_to_position(x):
return unnormalize_puppet_gripper_position(
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)

View File

@ -0,0 +1,311 @@
import importlib
import logging
from collections import deque
from typing import Optional
import einops
import numpy as np
import torch
from dm_control import mujoco
from dm_control.rl import control
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.envs.aloha.constants import (
ACTIONS,
ASSETS_DIR,
DT,
JOINTS,
)
from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
from lerobot.common.envs.aloha.tasks.sim_end_effector import (
InsertionEndEffectorTask,
TransferCubeEndEffectorTask,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
class AlohaEnv(AbstractEnv):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(
task=task,
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=image_size,
seed=seed,
device=device,
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
if not self.from_pixels:
raise NotImplementedError()
self._env = self._make_env_task(self.task)
def render(self, mode="rgb_array", width=640, height=480):
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
image = self._env.physics.render(height=height, width=width, camera_id="top")
return image
def _make_env_task(self, task_name):
# time limit is controlled by StepCounter in env factory
time_limit = float("inf")
if "sim_transfer_cube" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeTask(random=False)
elif "sim_insertion" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionTask(random=False)
elif "sim_end_effector_transfer_cube" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeEndEffectorTask(random=False)
elif "sim_end_effector_insertion" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionEndEffectorTask(random=False)
else:
raise NotImplementedError(task_name)
env = control.Environment(
physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
)
return env
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
image = torch.from_numpy(raw_obs["images"]["top"].copy())
image = einops.rearrange(image, "h w c -> c h w")
assert image.dtype == torch.uint8
obs = {"image": {"top": image}}
if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
else:
# TODO(rcadene):
raise NotImplementedError()
# obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs.observation)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
# TOOD(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
},
batch_size=[],
)
return td
def _make_spec(self):
obs = {}
from omegaconf import OmegaConf
if self.from_pixels:
if isinstance(self.image_size, int):
image_shape = (3, self.image_size, self.image_size)
elif OmegaConf.is_list(self.image_size):
assert len(self.image_size) == 3 # c h w
assert self.image_size[0] == 3 # c is RGB
image_shape = tuple(self.image_size)
else:
raise ValueError(self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = {
"top": BoundedTensorSpec(
low=0,
high=255,
shape=image_shape,
dtype=torch.uint8,
device=self.device,
)
}
if not self.pixels_only:
state_shape = (len(JOINTS),)
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO: add low and high bounds
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = (len(JOINTS),)
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO: add low and high bounds
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
self.observation_spec = CompositeSpec({"observation": obs})
# TODO(rcadene): valid when controling end effector?
# action_space = self._env.action_spec()
# self.action_spec = BoundedTensorSpec(
# low=action_space.minimum,
# high=action_space.maximum,
# shape=action_space.shape,
# dtype=torch.float32,
# device=self.device,
# )
# TODO(rcaene): add bounds (where are they????)
self.action_spec = BoundedTensorSpec(
shape=(len(ACTIONS)),
low=-1,
high=1,
dtype=torch.float32,
device=self.device,
)
self.reward_spec = UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
device=self.device,
)
self.done_spec = CompositeSpec(
{
"done": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
"success": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
}
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
# TODO(rcadene): seed the env
# self._env.seed(seed)
logging.warning("Aloha env is not seeded")

View File

@ -0,0 +1,219 @@
import collections
import numpy as np
from dm_control.suite import base
from lerobot.common.envs.aloha.constants import (
START_ARM_POSE,
normalize_puppet_gripper_position,
normalize_puppet_gripper_velocity,
unnormalize_puppet_gripper_position,
)
BOX_POSE = [None] # to be changed from outside
"""
Environment for simulated robot bi-manual manipulation, with joint position control
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
class BimanualViperXTask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
left_arm_action = action[:6]
right_arm_action = action[7 : 7 + 6]
normalized_left_gripper_action = action[6]
normalized_right_gripper_action = action[7 + 6]
left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
env_action = np.concatenate(
[left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
)
super().before_step(env_action, physics)
return
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos(physics)
obs["qvel"] = self.get_qvel(physics)
obs["env_state"] = self.get_env_state(physics)
obs["images"] = {}
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
return obs
def get_reward(self, physics):
# return whether left gripper is holding the box
raise NotImplementedError
class TransferCubeTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7:] = BOX_POSE[0]
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = (
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
)
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = (
("socket-1", "table") in all_contact_pairs
or ("socket-2", "table") in all_contact_pairs
or ("socket-3", "table") in all_contact_pairs
or ("socket-4", "table") in all_contact_pairs
)
peg_touch_socket = (
("red_peg", "socket-1") in all_contact_pairs
or ("red_peg", "socket-2") in all_contact_pairs
or ("red_peg", "socket-3") in all_contact_pairs
or ("red_peg", "socket-4") in all_contact_pairs
)
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if (
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward

View File

@ -0,0 +1,263 @@
import collections
import numpy as np
from dm_control.suite import base
from lerobot.common.envs.aloha.constants import (
PUPPET_GRIPPER_POSITION_CLOSE,
START_ARM_POSE,
normalize_puppet_gripper_position,
normalize_puppet_gripper_velocity,
unnormalize_puppet_gripper_position,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
"""
Environment for simulated robot bi-manual manipulation, with end-effector control.
Action space: [left_arm_pose (7), # position and quaternion for end effector
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_pose (7), # position and quaternion for end effector
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
class BimanualViperXEndEffectorTask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
a_len = len(action) // 2
action_left = action[:a_len]
action_right = action[a_len:]
# set mocap position and quat
# left
np.copyto(physics.data.mocap_pos[0], action_left[:3])
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
# right
np.copyto(physics.data.mocap_pos[1], action_right[:3])
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
# set gripper
g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
def initialize_robots(self, physics):
# reset joint position
physics.named.data.qpos[:16] = START_ARM_POSE
# reset mocap to align with end effector
# to obtain these numbers:
# (1) make an ee_sim env and reset to the same start_pose
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
# repeat the same for right side
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
# right
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
# reset gripper control
close_gripper_control = np.array(
[
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
]
)
np.copyto(physics.data.ctrl, close_gripper_control)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
# note: it is important to do .copy()
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos(physics)
obs["qvel"] = self.get_qvel(physics)
obs["env_state"] = self.get_env_state(physics)
obs["images"] = {}
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
# used in scripted policy to obtain starting pose
obs["mocap_pose_left"] = np.concatenate(
[physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
).copy()
obs["mocap_pose_right"] = np.concatenate(
[physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
).copy()
# used when replaying joint trajectory
obs["gripper_ctrl"] = physics.data.ctrl.copy()
return obs
def get_reward(self, physics):
raise NotImplementedError
class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize box position
cube_pose = sample_box_pose()
box_start_idx = physics.model.name2id("red_box_joint", "joint")
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize peg and socket position
peg_pose, socket_pose = sample_insertion_pose()
def id2index(j_id):
return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
peg_start_id = physics.model.name2id("red_peg_joint", "joint")
peg_start_idx = id2index(peg_start_id)
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
# print(f"randomized cube position to {cube_position}")
socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
socket_start_idx = id2index(socket_start_id)
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = (
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
)
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = (
("socket-1", "table") in all_contact_pairs
or ("socket-2", "table") in all_contact_pairs
or ("socket-3", "table") in all_contact_pairs
or ("socket-4", "table") in all_contact_pairs
)
peg_touch_socket = (
("red_peg", "socket-1") in all_contact_pairs
or ("red_peg", "socket-2") in all_contact_pairs
or ("red_peg", "socket-3") in all_contact_pairs
or ("red_peg", "socket-4") in all_contact_pairs
)
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if (
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward

View File

@ -0,0 +1,39 @@
import numpy as np
def sample_box_pose():
x_range = [0.0, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
cube_quat = np.array([1, 0, 0, 0])
return np.concatenate([cube_position, cube_quat])
def sample_insertion_pose():
# Peg
x_range = [0.1, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
peg_quat = np.array([1, 0, 0, 0])
peg_pose = np.concatenate([peg_position, peg_quat])
# Socket
x_range = [-0.2, -0.1]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
socket_quat = np.array([1, 0, 0, 0])
socket_pose = np.concatenate([socket_position, socket_quat])
return peg_pose, socket_pose

View File

@ -23,6 +23,11 @@ def make_env(cfg, transform=None):
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
clsfunc = PushtEnv
elif cfg.env.name == "aloha":
from lerobot.common.envs.aloha.env import AlohaEnv
kwargs["task"] = cfg.env.task
clsfunc = AlohaEnv
else:
raise ValueError(cfg.env.name)

View File

@ -11,61 +11,52 @@ from torchrl.data.tensor_specs import (
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
class PushtEnv(EnvBase):
class PushtEnv(AbstractEnv):
def __init__(
self,
task="pusht",
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=0,
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
super().__init__(
task=task,
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=image_size,
seed=seed,
device=device,
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
if not from_pixels:
if not self.from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
self._env = PushTImageEnv(render_size=self.image_size)
self._make_spec()
self._current_seed = self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def render(self, mode="rgb_array", width=384, height=384):
if width != height:
raise NotImplementedError()
@ -122,6 +113,8 @@ class PushtEnv(EnvBase):
)
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@ -154,6 +147,8 @@ class PushtEnv(EnvBase):
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
@ -175,9 +170,9 @@ class PushtEnv(EnvBase):
obs["image"] = BoundedTensorSpec(
low=0,
high=1,
high=255,
shape=image_shape,
dtype=torch.float32,
dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:

View File

@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv):
img = super()._render_frame(mode="rgb_array")
agent_pos = np.array(self.agent.position)
img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
img_obs = np.moveaxis(img, -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}
# draw action

View File

@ -1,6 +1,8 @@
import importlib
from collections import deque
from typing import Optional
import einops
import numpy as np
import torch
from tensordict import TensorDict
@ -10,9 +12,9 @@ from torchrl.data.tensor_specs import (
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvBase
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
MAX_NUM_ACTIONS = 4
@ -21,7 +23,7 @@ _has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
class SimxarmEnv(EnvBase):
class SimxarmEnv(AbstractEnv):
def __init__(
self,
task,
@ -31,19 +33,22 @@ class SimxarmEnv(EnvBase):
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=0,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
super().__init__(
task=task,
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=image_size,
seed=seed,
device=device,
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
def _make_env(self):
if not _has_simxarm:
raise ImportError("Cannot import simxarm.")
if not _has_gym:
@ -63,9 +68,6 @@ class SimxarmEnv(EnvBase):
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
self._make_spec()
self.set_seed(seed)
def render(self, mode="rgb_array", width=384, height=384):
return self._env.render(mode, width=width, height=height)
@ -90,15 +92,33 @@ class SimxarmEnv(EnvBase):
if td is None or td.is_empty():
raw_obs = self._env.reset()
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
@ -108,10 +128,32 @@ class SimxarmEnv(EnvBase):
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
for _ in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
@ -126,23 +168,36 @@ class SimxarmEnv(EnvBase):
def _make_spec(self):
obs = {}
if self.from_pixels:
image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = BoundedTensorSpec(
low=0,
high=255,
shape=(3, self.image_size, self.image_size),
shape=image_shape,
dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:
state_shape = (len(self._env.robot_state),)
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
shape=(len(self._env.robot_state),),
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
shape=self._env.observation_space["observation"].shape,
# TODO:
shape=state_shape,
dtype=torch.float32,
device=self.device,
)

View File

@ -1,5 +1,6 @@
from typing import Sequence
import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
@ -7,19 +8,45 @@ from torchrl.envs.transforms import ObservationTransform, Transform
class Prod(ObservationTransform):
invertible = True
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment reset to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td):
for key in self.in_keys:
td[key] *= self.prod
if td.get(key, None) is None:
continue
self.original_dtypes[key] = td[key].dtype
td[key] = td[key].type(torch.float32) * self.prod
return td
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
for key in self.in_keys:
if td.get(key, None) is None:
continue
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
return td
def transform_observation_spec(self, obs_spec):
for key in self.in_keys:
obs_spec[key].space.high *= self.prod
if obs_spec.get(key, None) is None:
continue
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
obs_spec[key].dtype = torch.float32
return obs_spec

View File

@ -0,0 +1,115 @@
from typing import List
import torch
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from .position_encoding import build_position_encoding
from .utils import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super().__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {"layer4": "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=FrozenBatchNorm2d,
) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for _, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model

View File

@ -0,0 +1,212 @@
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DETRVAE(nn.Module):
"""This is the DETR module that performs object detection"""
def __init__(
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.vae = vae
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
# TODO(rcadene): understand what is env_state, and why it needs to be 7
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(
hidden_dim, self.latent_dim * 2
) # project hidden state to latent std, var
self.register_buffer(
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
) # [CLS], qpos, a_seq
# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(
2, hidden_dim
) # learned position embedding for proprio and latent
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
### Obtain latent z from action sequence
if self.vae and is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat(
[cls_embed, qpos_embed, action_embed], axis=1
) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, : self.latent_dim]
logvar = latent_info[:, self.latent_dim :]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, _ in enumerate(self.camera_names):
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar]
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
if hidden_depth == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
for _ in range(hidden_depth - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
mods.append(nn.Linear(hidden_dim, output_dim))
trunk = nn.Sequential(*mods)
return trunk
def build_encoder(args):
d_model = args.hidden_dim # 256
dropout = args.dropout # 0.1
nhead = args.nheads # 8
dim_feedforward = args.dim_feedforward # 2048
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
normalize_before = args.pre_norm # False
activation = "relu"
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
return encoder
def build(args):
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
backbone = build_backbone(args)
backbones.append(backbone)
transformer = build_transformer(args)
encoder = build_encoder(args)
model = DETRVAE(
backbones,
transformer,
encoder,
state_dim=args.state_dim,
action_dim=args.action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
vae=args.vae,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
return model

View File

@ -0,0 +1,217 @@
import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from lerobot.common.policies.act.detr_vae import build
def build_act_model_and_optimizer(cfg):
model = build(cfg)
param_dicts = [
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr": cfg.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
return model, optimizer
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld
class ActionChunkingTransformerPolicy(nn.Module):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
def update(self, replay_buffer, step):
del step
start_time = time.time()
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
def process_batch(batch, horizon, num_slices):
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon)
image = batch["observation", "image", "top"]
image = image[:, 0] # first observation t=0
# batch, num_cam, channel, height, width
image = image.unsqueeze(1)
assert image.ndim == 5
image = image.float()
state = batch["observation", "state"]
state = state[:, 0] # first observation t=0
# batch, qpos_dim
assert state.ndim == 2
action = batch["action"]
# batch, seq, action_dim
assert action.ndim == 3
assert action.shape[1] == horizon
if self.cfg.n_obs_steps > 1:
raise NotImplementedError()
# # keep first n observations of the slice corresponding to t=[-1,0]
# image = image[:, : self.cfg.n_obs_steps]
# state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
loss = self.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
# self.lr_scheduler.step()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
# "lr": self.lr_scheduler.get_last_lr()[0],
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
}
return info
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
def compute_loss(self, batch):
loss_dict = self._forward(
qpos=batch["obs"]["agent_pos"],
image=batch["obs"]["image"],
actions=batch["action"],
)
loss = loss_dict["loss"]
return loss
@torch.no_grad()
def forward(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
self.eval()
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
# observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
obs_dict = {
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# remove bsize=1
action = action.squeeze(0)
# take first predicted action or n first actions
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
return action
def _forward(self, qpos, image, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
is_training = actions is not None
if is_training: # training time
actions = actions[:, : self.model.num_queries]
if is_pad is not None:
is_pad = is_pad[:, : self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict = {}
loss_dict["l1"] = l1
if self.cfg.vae:
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict["kl"] = total_kld[0]
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
return action

View File

@ -0,0 +1,101 @@
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from .utils import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos
def build_position_encoding(args):
n_steps = args.hidden_dim // 2
if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(n_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View File

@ -0,0 +1,370 @@
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
class Transformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
):
super().__init__()
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self,
src,
mask,
query_embed,
pos_embed,
latent_input=None,
proprio_input=None,
additional_pos_embed=None,
):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
bs, hw, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(1, 2)
return hs
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
)
def _get_clones(module, n):
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")

View File

@ -0,0 +1,477 @@
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import datetime
import os
import pickle
import subprocess
import time
from collections import defaultdict, deque
from typing import List, Optional
import torch
import torch.distributed as dist
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
from packaging import version
from torch import Tensor
if version.parse(torchvision.__version__) < version.parse("0.7"):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list, strict=False):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
return reduced_dict
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
mega_b = 1024.0 * 1024.0
for i, obj in enumerate(iterable):
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / mega_b,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
sha = "N/A"
diff = "clean"
branch = "N/A"
try:
sha = _run(["git", "rev-parse", "HEAD"])
subprocess.check_output(["git", "diff"], cwd=cwd)
diff = _run(["git", "diff-index", "HEAD"])
diff = "has uncommited changes" if diff else "clean"
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch, strict=False))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor:
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
torch.int64
)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if version.parse(torchvision.__version__) < version.parse("0.7"):
if input.numel() > 0:
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)

View File

@ -17,6 +17,12 @@ def make_policy(cfg):
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
)
else:
raise ValueError(cfg.policy.name)

View File

@ -15,11 +15,11 @@ env:
task: sim_insertion_human
from_pixels: True
pixels_only: False
image_size: 96
image_size: [3, 480, 640]
action_repeat: 1
episode_length: 300
episode_length: 400
fps: ${fps}
policy:
state_dim: 2
action_dim: 2
state_dim: 14
action_dim: 14

View File

@ -0,0 +1,58 @@
# @package _global_
offline_steps: 1344000
online_steps: 0
eval_episodes: 1
eval_freq: 10000
save_freq: 100000
log_freq: 250
horizon: 100
n_obs_steps: 1
n_latency_steps: 0
# when temporal_agg=False, n_action_steps=horizon
n_action_steps: ${horizon}
policy:
name: act
pretrained_model_path:
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
backbone: resnet18
num_queries: ${horizon} # chunk_size
horizon: ${horizon} # chunk_size
kl_weight: 10
hidden_dim: 512
dim_feedforward: 3200
enc_layers: 4
dec_layers: 7
nheads: 8
#camera_names: [top, front_close, left_pillar, right_pillar]
camera_names: [top]
position_embedding: sine
masks: false
dilation: false
dropout: 0.1
pre_norm: false
vae: true
batch_size: 8
per_alpha: 0.6
per_beta: 0.4
balanced_sampling: false
utd: 1
n_obs_steps: ${n_obs_steps}
temporal_agg: false
state_dim: ???
action_dim: ???

View File

@ -1,22 +0,0 @@
# TODO(rcadene): obsolete remove
import os
import zipfile
import gdown
def download():
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
download_path = "data.zip"
gdown.download(url, download_path, quiet=False)
print("Extracting...")
with zipfile.ZipFile(download_path, "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
os.remove(download_path)
if __name__ == "__main__":
download()

View File

@ -38,27 +38,18 @@ def eval_policy(
successes = []
threads = []
for i in tqdm.tqdm(range(num_episodes)):
tensordict = env.reset()
ep_frames = []
if save_video or (return_first_video and i == 0):
def rendering_callback(env, td=None):
def render_frame(env):
ep_frames.append(env.render()) # noqa: B023
# render first frame before rollout
rendering_callback(env)
else:
rendering_callback = None
env.register_rendering_hook(render_frame)
with torch.inference_mode():
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
callback=rendering_callback,
auto_reset=False,
tensordict=tensordict,
auto_cast_to_device=True,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
@ -85,6 +76,8 @@ def eval_policy(
if return_first_video and i == 0:
first_video = stacked_frames.transpose(0, 3, 1, 2)
env.reset_rendering_hooks()
for thread in threads:
thread.join()

View File

@ -1,4 +1,5 @@
import logging
from pathlib import Path
import hydra
import numpy as np
@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length // cfg.n_action_steps,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,
)
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
if cfg.wandb.enable:

479
poetry.lock generated
View File

@ -340,69 +340,69 @@ files = [
[[package]]
name = "cython"
version = "3.0.8"
version = "3.0.9"
description = "The Cython compiler for writing C extensions in the Python language."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "Cython-3.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a846e0a38e2b24e9a5c5dc74b0e54c6e29420d88d1dafabc99e0fc0f3e338636"},
{file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45523fdc2b78d79b32834cc1cc12dc2ca8967af87e22a3ee1bff20e77c7f5520"},
{file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa0b7f3f841fe087410cab66778e2d3fb20ae2d2078a2be3dffe66c6574be39"},
{file = "Cython-3.0.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e87294e33e40c289c77a135f491cd721bd089f193f956f7b8ed5aa2d0b8c558f"},
{file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a1df7a129344b1215c20096d33c00193437df1a8fcca25b71f17c23b1a44f782"},
{file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:13c2a5e57a0358da467d97667297bf820b62a1a87ae47c5f87938b9bb593acbd"},
{file = "Cython-3.0.8-cp310-cp310-win32.whl", hash = "sha256:96b028f044f5880e3cb18ecdcfc6c8d3ce9d0af28418d5ab464509f26d8adf12"},
{file = "Cython-3.0.8-cp310-cp310-win_amd64.whl", hash = "sha256:8140597a8b5cc4f119a1190f5a2228a84f5ca6d8d9ec386cfce24663f48b2539"},
{file = "Cython-3.0.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aae26f9663e50caf9657148403d9874eea41770ecdd6caf381d177c2b1bb82ba"},
{file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:547eb3cdb2f8c6f48e6865d5a741d9dd051c25b3ce076fbca571727977b28ac3"},
{file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a567d4b9ba70b26db89d75b243529de9e649a2f56384287533cf91512705bee"},
{file = "Cython-3.0.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51d1426263b0e82fb22bda8ea60dc77a428581cc19e97741011b938445d383f1"},
{file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c26daaeccda072459b48d211415fd1e5507c06bcd976fa0d5b8b9f1063467d7b"},
{file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:289ce7838208211cd166e975865fd73b0649bf118170b6cebaedfbdaf4a37795"},
{file = "Cython-3.0.8-cp311-cp311-win32.whl", hash = "sha256:c8aa05f5e17f8042a3be052c24f2edc013fb8af874b0bf76907d16c51b4e7871"},
{file = "Cython-3.0.8-cp311-cp311-win_amd64.whl", hash = "sha256:000dc9e135d0eec6ecb2b40a5b02d0868a2f8d2e027a41b0fe16a908a9e6de02"},
{file = "Cython-3.0.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90d3fe31db55685d8cb97d43b0ec39ef614fcf660f83c77ed06aa670cb0e164f"},
{file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e24791ddae2324e88e3c902a765595c738f19ae34ee66bfb1a6dac54b1833419"},
{file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f020fa1c0552052e0660790b8153b79e3fc9a15dbd8f1d0b841fe5d204a6ae6"},
{file = "Cython-3.0.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18bfa387d7a7f77d7b2526af69a65dbd0b731b8d941aaff5becff8e21f6d7717"},
{file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fe81b339cffd87c0069c6049b4d33e28bdd1874625ee515785bf42c9fdff3658"},
{file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:80fd94c076e1e1b1ee40a309be03080b75f413e8997cddcf401a118879863388"},
{file = "Cython-3.0.8-cp312-cp312-win32.whl", hash = "sha256:85077915a93e359a9b920280d214dc0cf8a62773e1f3d7d30fab8ea4daed670c"},
{file = "Cython-3.0.8-cp312-cp312-win_amd64.whl", hash = "sha256:0cb2dcc565c7851f75d496f724a384a790fab12d1b82461b663e66605bec429a"},
{file = "Cython-3.0.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:870d2a0a7e3cbd5efa65aecdb38d715ea337a904ea7bb22324036e78fb7068e7"},
{file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e8f2454128974905258d86534f4fd4f91d2f1343605657ecab779d80c9d6d5e"},
{file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1949d6aa7bc792554bee2b67a9fe41008acbfe22f4f8df7b6ec7b799613a4b3"},
{file = "Cython-3.0.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9f2c6e1b8f3bcd6cb230bac1843f85114780bb8be8614855b1628b36bb510e0"},
{file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:05d7eddc668ae7993643f32c7661f25544e791edb745758672ea5b1a82ecffa6"},
{file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bfabe115deef4ada5d23c87bddb11289123336dcc14347011832c07db616dd93"},
{file = "Cython-3.0.8-cp36-cp36m-win32.whl", hash = "sha256:0c38c9f0bcce2df0c3347285863621be904ac6b64c5792d871130569d893efd7"},
{file = "Cython-3.0.8-cp36-cp36m-win_amd64.whl", hash = "sha256:6c46939c3983217d140999de7c238c3141f56b1ea349e47ca49cae899969aa2c"},
{file = "Cython-3.0.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:115f0a50f752da6c99941b103b5cb090da63eb206abbc7c2ad33856ffc73f064"},
{file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9c0f29246734561c90f36e70ed0506b61aa3d044e4cc4cba559065a2a741fae"},
{file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab75242869ff71e5665fe5c96f3378e79e792fa3c11762641b6c5afbbbbe026"},
{file = "Cython-3.0.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6717c06e9cfc6c1df18543cd31a21f5d8e378a40f70c851fa2d34f0597037abc"},
{file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9d3f74388db378a3c6fd06e79a809ed98df3f56484d317b81ee762dbf3c263e0"},
{file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ae7ac561fd8253a9ae96311e91d12af5f701383564edc11d6338a7b60b285a6f"},
{file = "Cython-3.0.8-cp37-cp37m-win32.whl", hash = "sha256:97b2a45845b993304f1799664fa88da676ee19442b15fdcaa31f9da7e1acc434"},
{file = "Cython-3.0.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9e2be2b340fea46fb849d378f9b80d3c08ff2e81e2bfbcdb656e2e3cd8c6b2dc"},
{file = "Cython-3.0.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2cde23c555470db3f149ede78b518e8274853745289c956a0e06ad8d982e4db9"},
{file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7990ca127e1f1beedaf8fc8bf66541d066ef4723ad7d8d47a7cbf842e0f47580"},
{file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b983c8e6803f016146c26854d9150ddad5662960c804ea7f0c752c9266752f0"},
{file = "Cython-3.0.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a973268d7ca1a2bdf78575e459a94a78e1a0a9bb62a7db0c50041949a73b02ff"},
{file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:61a237bc9dd23c7faef0fcfce88c11c65d0c9bb73c74ccfa408b3a012073c20e"},
{file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:3a3d67f079598af49e90ff9655bf85bd358f093d727eb21ca2708f467c489cae"},
{file = "Cython-3.0.8-cp38-cp38-win32.whl", hash = "sha256:17a642bb01a693e34c914106566f59844b4461665066613913463a719e0dd15d"},
{file = "Cython-3.0.8-cp38-cp38-win_amd64.whl", hash = "sha256:2cdfc32252f3b6dc7c94032ab744dcedb45286733443c294d8f909a4854e7f83"},
{file = "Cython-3.0.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa97893d99385386925d00074654aeae3a98867f298d1e12ceaf38a9054a9bae"},
{file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f05c0bf9d085c031df8f583f0d506aa3be1692023de18c45d0aaf78685bbb944"},
{file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de892422582f5758bd8de187e98ac829330ec1007bc42c661f687792999988a7"},
{file = "Cython-3.0.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:314f2355a1f1d06e3c431eaad4708cf10037b5e91e4b231d89c913989d0bdafd"},
{file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:78825a3774211e7d5089730f00cdf7f473042acc9ceb8b9eeebe13ed3a5541de"},
{file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:df8093deabc55f37028190cf5e575c26aad23fc673f34b85d5f45076bc37ce39"},
{file = "Cython-3.0.8-cp39-cp39-win32.whl", hash = "sha256:1aca1b97e0095b3a9a6c33eada3f661a4ed0d499067d121239b193e5ba3bb4f0"},
{file = "Cython-3.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:16873d78be63bd38ffb759da7ab82814b36f56c769ee02b1d5859560e4c3ac3c"},
{file = "Cython-3.0.8-py2.py3-none-any.whl", hash = "sha256:171b27051253d3f9108e9759e504ba59ff06e7f7ba944457f94deaf9c21bf0b6"},
{file = "Cython-3.0.8.tar.gz", hash = "sha256:8333423d8fd5765e7cceea3a9985dd1e0a5dfeb2734629e1a2ed2d6233d39de6"},
{file = "Cython-3.0.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:296bd30d4445ac61b66c9d766567f6e81a6e262835d261e903c60c891a6729d3"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f496b52845cb45568a69d6359a2c335135233003e708ea02155c10ce3548aa89"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:858c3766b9aa3ab8a413392c72bbab1c144a9766b7c7bfdef64e2e414363fa0c"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0eb1e6ef036028a52525fd9a012a556f6dd4788a0e8755fe864ba0e70cde2ff"},
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c8191941073ea5896321de3c8c958fd66e5f304b0cd1f22c59edd0b86c4dd90d"},
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e32b016030bc72a8a22a1f21f470a2f57573761a4f00fbfe8347263f4fbdb9f1"},
{file = "Cython-3.0.9-cp310-cp310-win32.whl", hash = "sha256:d6f3ff1cd6123973fe03e0fb8ee936622f976c0c41138969975824d08886572b"},
{file = "Cython-3.0.9-cp310-cp310-win_amd64.whl", hash = "sha256:56f3b643dbe14449248bbeb9a63fe3878a24256664bc8c8ef6efd45d102596d8"},
{file = "Cython-3.0.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:35e6665a20d6b8a152d72b7fd87dbb2af6bb6b18a235b71add68122d594dbd41"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92f4960c40ad027bd8c364c50db11104eadc59ffeb9e5b7f605ca2f05946e20"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38df37d0e732fbd9a2fef898788492e82b770c33d1e4ed12444bbc8a3b3f89c0"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad7fd88ebaeaf2e76fd729a8919fae80dab3d6ac0005e28494261d52ff347a8f"},
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1365d5f76bf4d19df3d19ce932584c9bb76e9fb096185168918ef9b36e06bfa4"},
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c232e7f279388ac9625c3e5a5a9f0078a9334959c5d6458052c65bbbba895e1e"},
{file = "Cython-3.0.9-cp311-cp311-win32.whl", hash = "sha256:357e2fad46a25030b0c0496487e01a9dc0fdd0c09df0897f554d8ba3c1bc4872"},
{file = "Cython-3.0.9-cp311-cp311-win_amd64.whl", hash = "sha256:1315aee506506e8d69cf6631d8769e6b10131fdcc0eb66df2698f2a3ddaeeff2"},
{file = "Cython-3.0.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:157973807c2796addbed5fbc4d9c882ab34bbc60dc297ca729504901479d5df7"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00b105b5d050645dd59e6767bc0f18b48a4aa11c85f42ec7dd8181606f4059e3"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac5536d09bef240cae0416d5a703d298b74c7bbc397da803ac9d344e732d4369"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c44501d476d16aaa4cbc29c87f8c0f54fc20e69b650d59cbfa4863426fc70c"},
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc9c3b9f20d8e298618e5ccd32083ca386e785b08f9893fbec4c50b6b85be772"},
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a30d96938c633e3ec37000ac3796525da71254ef109e66bdfd78f29891af6454"},
{file = "Cython-3.0.9-cp312-cp312-win32.whl", hash = "sha256:757ca93bdd80702546df4d610d2494ef2e74249cac4d5ba9464589fb464bd8a3"},
{file = "Cython-3.0.9-cp312-cp312-win_amd64.whl", hash = "sha256:1dc320a9905ab95414013f6de805efbff9e17bb5fb3b90bbac533f017bec8136"},
{file = "Cython-3.0.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4ae349960ebe0da0d33724eaa7f1eb866688fe5434cc67ce4dbc06d6a719fbfc"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63d2537bf688247f76ded6dee28ebd26274f019309aef1eb4f2f9c5c482fde2d"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f5a2dfc724bea1f710b649f02d802d80fc18320c8e6396684ba4a48412445a"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deaf4197d4b0bcd5714a497158ea96a2bd6d0f9636095437448f7e06453cc83d"},
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:000af6deb7412eb7ac0c635ff5e637fb8725dd0a7b88cc58dfc2b3de14e701c4"},
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:15c7f5c2d35bed9aa5f2a51eaac0df23ae72f2dbacf62fc672dd6bfaa75d2d6f"},
{file = "Cython-3.0.9-cp36-cp36m-win32.whl", hash = "sha256:f49aa4970cd3bec66ac22e701def16dca2a49c59cceba519898dd7526e0be2c0"},
{file = "Cython-3.0.9-cp36-cp36m-win_amd64.whl", hash = "sha256:4558814fa025b193058d42eeee498a53d6b04b2980d01339fc2444b23fd98e58"},
{file = "Cython-3.0.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:539cd1d74fd61f6cfc310fa6bbbad5adc144627f2b7486a07075d4e002fd6aad"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232926cd406ee02eabb732206f6e882c3aed9d58f0fea764013d9240405bcf"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33b6ac376538a7fc8c567b85d3c71504308a9318702ec0485dd66c059f3165cb"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cc92504b5d22ac66031ffb827bd3a967fc75a5f0f76ab48bce62df19be6fdfd"},
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:22b8fae756c5c0d8968691bed520876de452f216c28ec896a00739a12dba3bd9"},
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9cda0d92a09f3520f29bd91009f1194ba9600777c02c30c6d2d4ac65fb63e40d"},
{file = "Cython-3.0.9-cp37-cp37m-win32.whl", hash = "sha256:ec612418490941ed16c50c8d3784c7bdc4c4b2a10c361259871790b02ec8c1db"},
{file = "Cython-3.0.9-cp37-cp37m-win_amd64.whl", hash = "sha256:976c8d2bedc91ff6493fc973d38b2dc01020324039e2af0e049704a8e1b22936"},
{file = "Cython-3.0.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5055988b007c92256b6e9896441c3055556038c3497fcbf8c921a6c1fce90719"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9360606d964c2d0492a866464efcf9d0a92715644eede3f6a2aa696de54a137"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c6e809f060bed073dc7cba1648077fe3b68208863d517c8b39f3920eecf9dd"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95ed792c966f969cea7489c32ff90150b415c1f3567db8d5a9d489c7c1602dac"},
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8edd59d22950b400b03ca78d27dc694d2836a92ef0cac4f64cb4b2ff902f7e25"},
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4cf0ed273bf60e97922fcbbdd380c39693922a597760160b4b4355e6078ca188"},
{file = "Cython-3.0.9-cp38-cp38-win32.whl", hash = "sha256:5eb9bd4ae12ebb2bc79a193d95aacf090fbd8d7013e11ed5412711650cb34934"},
{file = "Cython-3.0.9-cp38-cp38-win_amd64.whl", hash = "sha256:44457279da56e0f829bb1fc5a5dc0836e5d498dbcf9b2324f32f7cc9d2ec6569"},
{file = "Cython-3.0.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4b419a1adc2af43f4660e2f6eaf1e4fac2dbac59490771eb8ac3d6063f22356"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f836192140f033b2319a0128936367c295c2b32e23df05b03b672a6015757ea"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd198c1a7f8e9382904d622cc0efa3c184605881fd5262c64cbb7168c4c1ec5"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a274fe9ca5c53fafbcf5c8f262f8ad6896206a466f0eeb40aaf36a7951e957c0"},
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:158c38360bbc5063341b1e78d3737f1251050f89f58a3df0d10fb171c44262be"},
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bf30b045f7deda0014b042c1b41c1d272facc762ab657529e3b05505888e878"},
{file = "Cython-3.0.9-cp39-cp39-win32.whl", hash = "sha256:9a001fd95c140c94d934078544ff60a3c46aca2dc86e75a76e4121d3cd1f4b33"},
{file = "Cython-3.0.9-cp39-cp39-win_amd64.whl", hash = "sha256:530c01c4aebba709c0ec9c7ecefe07177d0b9fd7ffee29450a118d92192ccbdf"},
{file = "Cython-3.0.9-py2.py3-none-any.whl", hash = "sha256:bf96417714353c5454c2e3238fca9338599330cf51625cdc1ca698684465646f"},
{file = "Cython-3.0.9.tar.gz", hash = "sha256:a2d354f059d1f055d34cfaa62c5b68bc78ac2ceab6407148d47fb508cf3ba4f3"},
]
[[package]]
@ -488,6 +488,37 @@ files = [
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
]
[[package]]
name = "dm-control"
version = "1.0.14"
description = "Continuous control environments and MuJoCo Python bindings."
optional = false
python-versions = ">=3.8"
files = [
{file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"},
{file = "dm_control-1.0.14.tar.gz", hash = "sha256:def1ece747b6f175c581150826b50f1a6134086dab34f8f3fd2d088ea035cf3d"},
]
[package.dependencies]
absl-py = ">=0.7.0"
dm-env = "*"
dm-tree = "!=0.1.2"
glfw = "*"
labmaze = "*"
lxml = "*"
mujoco = ">=2.3.7"
numpy = ">=1.9.0"
protobuf = ">=3.19.4"
pyopengl = ">=3.1.4"
pyparsing = ">=3.0.0"
requests = "*"
scipy = "*"
setuptools = "!=50.0.0"
tqdm = "*"
[package.extras]
hdf5 = ["h5py"]
[[package]]
name = "dm-env"
version = "1.6"
@ -584,43 +615,6 @@ files = [
{file = "einops-0.7.0.tar.gz", hash = "sha256:b2b04ad6081a3b227080c9bf5e3ace7160357ff03043cd66cc5b2319eb7031d1"},
]
[[package]]
name = "etils"
version = "1.7.0"
description = "Collection of common python utils"
optional = false
python-versions = ">=3.10"
files = [
{file = "etils-1.7.0-py3-none-any.whl", hash = "sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60"},
{file = "etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350"},
]
[package.dependencies]
fsspec = {version = "*", optional = true, markers = "extra == \"epath\""}
importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""}
typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""}
zipp = {version = "*", optional = true, markers = "extra == \"epath\""}
[package.extras]
all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"]
array-types = ["etils[enp]"]
dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"]
docs = ["etils[all,dev]", "sphinx-apitree[ext]"]
eapp = ["absl-py", "etils[epy]", "simple_parsing"]
ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"]
edc = ["etils[epy]"]
enp = ["etils[epy]", "numpy"]
epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"]
epath-gcs = ["etils[epath]", "gcsfs"]
epath-s3 = ["etils[epath]", "s3fs"]
epy = ["typing_extensions"]
etqdm = ["absl-py", "etils[epy]", "tqdm"]
etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"]
etree-dm = ["dm-tree", "etils[etree]"]
etree-jax = ["etils[etree]", "jax[cpu]"]
etree-tf = ["etils[etree]", "tensorflow"]
lazy-imports = ["etils[ecolab]"]
[[package]]
name = "exceptiongroup"
version = "1.2.0"
@ -846,13 +840,13 @@ numpy = ">=1.17.3"
[[package]]
name = "huggingface-hub"
version = "0.21.3"
version = "0.21.4"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.21.3-py3-none-any.whl", hash = "sha256:b183144336fdf2810a8c109822e0bb6ef1fd61c65da6fb60e8c3f658b7144016"},
{file = "huggingface_hub-0.21.3.tar.gz", hash = "sha256:26a15b604e4fc7bad37c467b76456543ec849386cbca9cd7e1e135f53e500423"},
{file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
{file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"},
]
[package.dependencies]
@ -971,37 +965,22 @@ setuptools = "*"
[[package]]
name = "importlib-metadata"
version = "7.0.1"
version = "7.0.2"
description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "importlib_metadata-7.0.1-py3-none-any.whl", hash = "sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e"},
{file = "importlib_metadata-7.0.1.tar.gz", hash = "sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc"},
{file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"},
{file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"},
]
[package.dependencies]
zipp = ">=0.5"
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"]
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
[[package]]
name = "importlib-resources"
version = "6.1.2"
description = "Read resources from Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "importlib_resources-6.1.2-py3-none-any.whl", hash = "sha256:9a0a862501dc38b68adebc82970140c9e4209fc99601782925178f8386339938"},
{file = "importlib_resources-6.1.2.tar.gz", hash = "sha256:308abf8474e2dba5f867d279237cd4076482c3de7104a40b41426370e891549b"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]]
name = "iniconfig"
@ -1031,6 +1010,50 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
[[package]]
name = "labmaze"
version = "1.0.6"
description = "LabMaze: DeepMind Lab's text maze generator."
optional = false
python-versions = "*"
files = [
{file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"},
{file = "labmaze-1.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:157efaa93228c8ccce5cae337902dd652093e0fba9d3a0f6506e4bee272bb66f"},
{file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3ce98b9541c5fe6a306e411e7d018121dd646f2c9978d763fad86f9f30c5f57"},
{file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e6433bd49bc541791de8191040526fddfebb77151620eb04203453f43ee486a"},
{file = "labmaze-1.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:6a507fc35961f1b1479708e2716f65e0d0611cefb55f31a77be29ce2339b6fef"},
{file = "labmaze-1.0.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a0c2cb9dec971814ea9c5d7150af15fa3964482131fa969e0afb94bd224348af"},
{file = "labmaze-1.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c6ba9538d819543f4be448d36b4926a3881e53646a2b331ebb5a1f353047d05"},
{file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70635d1cdb0147a02efb6b3f607a52cdc51723bc3dcc42717a0d4ef55fa0a987"},
{file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff472793238bd9b6dabea8094594d6074ad3c111455de3afcae72f6c40c6817e"},
{file = "labmaze-1.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:2317e65e12fa3d1abecda7e0488dab15456cee8a2e717a586bfc8f02a91579e7"},
{file = "labmaze-1.0.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e36b6fadcd78f22057b597c1c77823e806a0987b3bdfbf850e14b6b5b502075e"},
{file = "labmaze-1.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d1a4f8de29c2c3d7f14163759b69cd3f237093b85334c983619c1db5403a223b"},
{file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a394f8bb857fcaa2884b809d63e750841c2662a106cfe8c045f2112d201ac7d5"},
{file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d17abb69d4dfc56183afb5c317e8b2eaca0587abb3aabd2326efd3143c81f4e"},
{file = "labmaze-1.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:5af997598cc46b1929d1c5a1febc32fd56c75874fe481a2a5982c65cee8450c9"},
{file = "labmaze-1.0.6-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:a4c5bc6e56baa55ce63b97569afec2f80cab0f6b952752a131e1f83eed190a53"},
{file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3955f24fe5f708e1e97495b4cfe284b70ae4fd51be5e17b75a6fc04ffbd67bca"},
{file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed96ddc0bb8d66df36428c94db83949fd84a15867e8250763a4c5e3d82104c54"},
{file = "labmaze-1.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:3bd0458a29e55aa09f146e28a168d2e00b8ccf19e2259a3f71154cfff3536b1d"},
{file = "labmaze-1.0.6-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:33f5154edc83dff55a150e54b60c8582fdafc7ec45195049809cbcc01f5e8f34"},
{file = "labmaze-1.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0971055ef2a5f7d8517fdc42b67c057093698f1eb911f46faa7018867b73fcc9"},
{file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de18d09680007302abf49111f3fe822d8435e4fbc4468b9ec07d50a78e267865"},
{file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f18126066db2218a52853c7dd490b4c3d8129fc22eb3a47eb23007524b911d53"},
{file = "labmaze-1.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:f9aef09a76877342bb4d634b7e05f43b038a49c4f34adfb8f1b8ac57c29472f2"},
{file = "labmaze-1.0.6-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5dd28899418f1b8b1c7d1e1b40a4593150a7cfa95ca91e23860b9785b82cc0ee"},
{file = "labmaze-1.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:965569f37ee33090b4d4b3aa5aa7c9dcc4f62e2ae5d761e7f73ec76fc9d8aa96"},
{file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05eccfa98c0e781bc9f939076ae600b2e25ca736e123f2a530606aedec3b531c"},
{file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee8c94e0fb3fc2d8180214947245c1d74a3489349a9da90b868296e77a521e9"},
{file = "labmaze-1.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:d486e9ca3a335ad628e3bd48a09c42f1aa5f51040952ef0fe32507afedcd694b"},
{file = "labmaze-1.0.6.tar.gz", hash = "sha256:2e8de7094042a77d6972f1965cf5c9e8f971f1b34d225752f343190a825ebe73"},
]
[package.dependencies]
absl-py = "*"
numpy = ">=1.8.0"
setuptools = "!=50.0.0"
[[package]]
name = "lazy-loader"
version = "0.3"
@ -1076,6 +1099,99 @@ files = [
{file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"},
]
[[package]]
name = "lxml"
version = "5.1.0"
description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API."
optional = false
python-versions = ">=3.6"
files = [
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
{file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"},
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"},
{file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"},
{file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
{file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
{file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
{file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"},
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"},
{file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"},
{file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
{file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
{file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
{file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"},
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"},
{file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"},
{file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"},
{file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"},
{file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"},
{file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"},
{file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"},
{file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"},
{file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"},
{file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"},
{file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"},
{file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"},
{file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"},
{file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"},
{file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"},
{file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"},
{file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"},
{file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"},
{file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
{file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
{file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
{file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"},
{file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
{file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
{file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
{file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"},
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"},
{file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"},
{file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"},
{file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"},
{file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"},
{file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"},
{file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"},
{file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"},
{file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"},
{file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"},
{file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"},
{file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"},
{file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"},
{file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"},
{file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"},
{file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"},
{file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"},
{file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"},
]
[package.extras]
cssselect = ["cssselect (>=0.7)"]
html5 = ["html5lib"]
htmlsoup = ["BeautifulSoup4"]
source = ["Cython (>=3.0.7)"]
[[package]]
name = "markupsafe"
version = "2.1.5"
@ -1188,42 +1304,40 @@ tests = ["pytest (>=4.6)"]
[[package]]
name = "mujoco"
version = "3.1.2"
version = "2.3.7"
description = "MuJoCo Physics Simulator"
optional = false
python-versions = ">=3.8"
files = [
{file = "mujoco-3.1.2-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:fe6b3542695a5363f348ee45625b3492734f29cdc9f493ca25eae719f974370e"},
{file = "mujoco-3.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f07e2d1f01f1401f1a503187016f8c017d9402618c659e1482243640a1e11288"},
{file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93863eccc9d77d96ce62dda2a6f61cbd880379e8d774f802568d64b9613fce39"},
{file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3586c642390c16fef58b01a86071cec6814c471586e2f4115c3733c4aec64fb7"},
{file = "mujoco-3.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0da77394c664945b78f199c627b609fe091ec0c4641b9d8f713637344a17821a"},
{file = "mujoco-3.1.2-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:b6f12904d0478c191e4770ecf9006e20953f0488a2411a8ddc62592721c136dc"},
{file = "mujoco-3.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f69b8d42b50c10f8d12df4948fc9d4dd6706841e7b163c1d7ce83448965acb1c"},
{file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10119e39b1f45fb76b18bea242fea1d6ccf4b2285f8bd5e2cb1e2cbdeb69bdcd"},
{file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a65868506dd45dddfe7be84857e57b49bc102334fc0439aa848a4d4d285d89b"},
{file = "mujoco-3.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:92bc73972e39539f23a05bb411c45f9be17191fe01171ac15ffafed381ee4366"},
{file = "mujoco-3.1.2-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:835d6b64ca4dc2f6a83291275fd48bd83edc888039d247958bf5b2c759db4340"},
{file = "mujoco-3.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ce94ca3cf14fc519981674c5b85f1055356dcdcd63bbc0ec6c340084438f27f"},
{file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:250d9de4bd0d31fa4165faf01a1f838c429434f1263faacd95b977580f24eae7"},
{file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ea009d10bbf0aba9bc835f051d25f07a2c3edbaa06627ac2348766a1f3760b9"},
{file = "mujoco-3.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:a0460d2ebdad4926f48b8c774da473e011c3b3afd0ccb6b6be1087b788c34011"},
{file = "mujoco-3.1.2-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:4ca7cae89e258a338e02229edcf8f177b459ac5e9f859ffffa07fc2c9fcfb6aa"},
{file = "mujoco-3.1.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:33b4fe9b5f891b29ef0fc2b0b975bc3a8a4b87774eecaf4364a83ddc6a7762ba"},
{file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed230980f33bafaf1fa8b32ef25b82b069a245de15ee6ce7127e7e054cfad16"},
{file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41cc610ac40f325c9d49d9885ac6cb61822ed938f6c23cb183b261a7a28472ca"},
{file = "mujoco-3.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:90a172b904a6ca8e6a1be80ab7c393aaff7592843a2a6853a4f97a9204031c41"},
{file = "mujoco-3.1.2-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:93201291a0c5b573b4cbb19a6b08c99673f9fba167f402174eae5ffa23066d24"},
{file = "mujoco-3.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0398985bb28c2686cdeeaf4ded46e602a49ec12115ac77474144ca940e5261c5"},
{file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2e76b5cb07ab3088c81966ac774d573df027fa5f4e78c20953a547528a2a698"},
{file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd5c3f4ae858e812cb3f03332693bcdc343b2bce55b164523acf52dea2736c9e"},
{file = "mujoco-3.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:ca25ff2646b06609526ef8681c0e123cd854a53c9ff23cb91dd5058a2794dab4"},
{file = "mujoco-3.1.2.tar.gz", hash = "sha256:53530bc1a91903f3fd4b1e99818cc38fbd9911700db29b2c9fc839f23bfacbb8"},
{file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"},
{file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"},
{file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"},
{file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"},
{file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"},
{file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"},
{file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"},
{file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"},
{file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"},
{file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"},
{file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"},
{file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"},
{file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"},
{file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"},
{file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"},
{file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"},
{file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"},
{file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"},
{file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"},
{file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"},
{file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"},
{file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"},
{file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"},
{file = "mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"},
{file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"},
]
[package.dependencies]
absl-py = "*"
etils = {version = "*", extras = ["epath"]}
glfw = "*"
numpy = "*"
pyopengl = "*"
@ -1519,13 +1633,13 @@ files = [
[[package]]
name = "nvidia-nvjitlink-cu12"
version = "12.3.101"
version = "12.4.99"
description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
{file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
{file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
]
[[package]]
@ -1580,13 +1694,13 @@ numpy = [
[[package]]
name = "packaging"
version = "23.2"
version = "24.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.7"
files = [
{file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
{file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
{file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
]
[[package]]
@ -2016,6 +2130,20 @@ files = [
{file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"},
]
[[package]]
name = "pyparsing"
version = "3.1.2"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
optional = false
python-versions = ">=3.6.8"
files = [
{file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
{file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
]
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pysocks"
version = "1.7.1"
@ -2030,13 +2158,13 @@ files = [
[[package]]
name = "pytest"
version = "8.1.0"
version = "8.1.1"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-8.1.0-py3-none-any.whl", hash = "sha256:ee32db7af8de4629a455806befa90559f307424c07b8413ccfc30bf5b221dd7e"},
{file = "pytest-8.1.0.tar.gz", hash = "sha256:f8fa04ab8f98d185113ae60ea6d79c22f8143b14bc1caeced44a0ab844928323"},
{file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"},
{file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"},
]
[package.dependencies]
@ -2052,13 +2180,13 @@ testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygm
[[package]]
name = "python-dateutil"
version = "2.8.2"
version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
]
[package.dependencies]
@ -2483,13 +2611,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
[[package]]
name = "sentry-sdk"
version = "1.40.6"
version = "1.41.0"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
{file = "sentry-sdk-1.40.6.tar.gz", hash = "sha256:f143f3fb4bb57c90abef6e2ad06b5f6f02b2ca13e4060ec5c0549c7a9ccce3fa"},
{file = "sentry_sdk-1.40.6-py2.py3-none-any.whl", hash = "sha256:becda09660df63e55f307570e9817c664392655a7328bbc414b507e9cb874c67"},
{file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"},
{file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"},
]
[package.dependencies]
@ -2515,7 +2643,7 @@ huey = ["huey (>=2)"]
loguru = ["loguru (>=0.5)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
pure-eval = ["asttokens", "executing", "pure_eval"]
pure-eval = ["asttokens", "executing", "pure-eval"]
pymongo = ["pymongo (>=3.1)"]
pyspark = ["pyspark (>=2.4.4)"]
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
@ -2769,7 +2897,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
resolved_reference = "551331d83e2979dd4505db1f49895740e6e5c95f"
resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a"
[[package]]
name = "termcolor"
@ -2888,13 +3016,13 @@ tensordict = ">=0.4.0"
torch = ">=2.1.0"
[package.extras]
all = ["ale-py", "atari-py", "dm-control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface-hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
checkpointing = ["torchsnapshot"]
dm-control = ["dm-control"]
dm-control = ["dm_control"]
gym-continuous = ["gymnasium", "mujoco"]
marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
offline-data = ["h5py", "huggingface-hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
rendering = ["moviepy"]
tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
@ -3051,13 +3179,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
[[package]]
name = "wandb"
version = "0.16.3"
version = "0.16.4"
description = "A CLI and library for interacting with the Weights & Biases API."
optional = false
python-versions = ">=3.7"
files = [
{file = "wandb-0.16.3-py3-none-any.whl", hash = "sha256:b8907ddd775c27dc6c12687386a86b5d6acf291060f9ae680bbc61cc8fc03237"},
{file = "wandb-0.16.3.tar.gz", hash = "sha256:d789acda32053b18b7a160d0595837e45a3c8a79d25e1fe1f051875303f480ec"},
{file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"},
{file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"},
]
[package.dependencies]
@ -3078,8 +3206,9 @@ async = ["httpx (>=0.23.0)"]
aws = ["boto3"]
azure = ["azure-identity", "azure-storage-blob"]
gcp = ["google-cloud-storage"]
importers = ["filelock", "mlflow", "polars", "rich", "tenacity"]
kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "typing-extensions"]
launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"]
media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"]
models = ["cloudpickle"]
perf = ["orjson"]
@ -3088,13 +3217,13 @@ sweeps = ["sweeps (>=0.2.0)"]
[[package]]
name = "zarr"
version = "2.17.0"
version = "2.17.1"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "zarr-2.17.0-py3-none-any.whl", hash = "sha256:d287cb61019c4a0a0f386f76eeaa7f0b1160b1cb90cf96173a4b6cbc135df6e1"},
{file = "zarr-2.17.0.tar.gz", hash = "sha256:6390a2b8af31babaab4c963efc45bf1da7f9500c9aafac193f84cf019a7c66b0"},
{file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"},
{file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"},
]
[package.dependencies]
@ -3125,4 +3254,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6"
content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8"

View File

@ -42,13 +42,14 @@ mpmath = "^1.3.0"
torch = "^2.2.1"
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^3.1.2"
mujoco = "2.3.7"
mujoco-py = "^2.1.2.14"
gym = "^0.26.2"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = "^0.17.1"
h5py = "^3.10.0"
dm-control = "1.0.14"
[tool.poetry.group.dev.dependencies]

View File

@ -17,6 +17,7 @@ apptainer exec --nv \
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
source ~/.bashrc
conda activate fowm
#conda activate fowm
conda activate lerobot
srun $CMD

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d789deddb081a9f4b626342391de8f48949d38fb5fdead87b5c0737b46c0877a
size 2800

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
size 400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
size 400

View File

@ -0,0 +1 @@
{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
size 50

View File

@ -0,0 +1 @@
{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1 @@
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c632e3cb06be729e5d673e3ecca1d6f6527b0f48cfe3dc03d7eea4f9eb3bbd7
size 46080000

View File

@ -0,0 +1 @@
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e231f2e07e1cd030137ea2e938b570b112db2c694c6d21b37ceb8f8559e19088
size 2800

View File

@ -0,0 +1 @@
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1ba64c89f4fcf9135fe34c26abf582dd5f0d573506db5c96af3ffe40a52c818
size 46080000

View File

@ -0,0 +1 @@
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85405686bc065c6ab6c915907920a0391a57cf097b74de058a8c30be0548ade5
size 2800

Binary file not shown.

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f5fe053b760e8471885b82c10f4a6ea40874098036337ae5cc300c4775546be
size 2800

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
size 400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
size 400

View File

@ -0,0 +1 @@
{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
size 50

View File

@ -0,0 +1 @@
{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1 @@
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:daed2bb10498ba2557983d0d7e89399882fea7585e7ceff910e23c621bfdbf88
size 46080000

View File

@ -0,0 +1 @@
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bbad0302af70112ee312efe0eb0f44a2f1c8f6c5ef82ea4fb34625cdafbef057
size 2800

View File

@ -0,0 +1 @@
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aba55ebb9dd004bf68444b9ebf024ed7713436099c06a0b8e541100ecbc69290
size 46080000

View File

@ -0,0 +1 @@
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd4e7e14abf57561ca9839c910581266be90956e41bfb3bb21362ea0c321e77d
size 2800

Binary file not shown.

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:14fed0eed3d529a8ac0dd25a6d41585020772d02f9137fc9d604713b2f0f7076
size 2800

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
size 400

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
size 400

View File

@ -0,0 +1 @@
{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
size 50

View File

@ -0,0 +1 @@
{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1 @@
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f713ea7fc19e592ea409a5e0bdfde403e5b86f834cbabe3463b791e8437fafc
size 46080000

View File

@ -0,0 +1 @@
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c103e2c9d63c9f7cf9645bd24d9a2c4e8e08825dc75e230ebc793b8f9c213b0
size 2800

View File

@ -0,0 +1 @@
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7dbf4aa01b184d0eaa21ea999078d7cff86e1ca484a109614176fdc49f1ee05c
size 46080000

View File

@ -0,0 +1 @@
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fa0b9c870d4615037b6fee9e9e85e54d84352e173f2c7c1035232272fe2a3dd
size 2800

Binary file not shown.

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0e199a82e2b7462e84406dbced5448a99f1dad9ce172771dfc3feb4b8597115
size 2800

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
size 400

Some files were not shown because too many files have changed in this diff Show More