Merge pull request #11 from Cadene/user/rcadene/2024_03_06_aloha_env
[WIP] Add Aloha env and ACT policy
This commit is contained in:
commit
8c56770318
|
@ -18,6 +18,12 @@ jobs:
|
|||
env:
|
||||
POETRY_VERSION: 1.8.1
|
||||
DATA_DIR: tests/data
|
||||
TMPDIR: ~/tmp
|
||||
TEMP: ~/tmp
|
||||
TMP: ~/tmp
|
||||
PYOPENGL_PLATFORM: egl
|
||||
MUJOCO_GL: egl
|
||||
LEROBOT_TESTS_DEVICE: cpu
|
||||
steps:
|
||||
#----------------------------------------------
|
||||
# check-out repo and set-up python
|
||||
|
@ -26,11 +32,13 @@ jobs:
|
|||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up python
|
||||
id: setup-python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
#----------------------------------------------
|
||||
# install & configure poetry
|
||||
#----------------------------------------------
|
||||
|
@ -40,6 +48,7 @@ jobs:
|
|||
with:
|
||||
path: ~/.local # the path depends on the OS
|
||||
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
|
||||
uses: snok/install-poetry@v1
|
||||
|
@ -47,6 +56,7 @@ jobs:
|
|||
version: ${{ env.POETRY_VERSION }}
|
||||
virtualenvs-create: true
|
||||
installer-parallel: true
|
||||
|
||||
- name: Save cached Poetry installation
|
||||
if: |
|
||||
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
|
||||
|
@ -56,8 +66,10 @@ jobs:
|
|||
with:
|
||||
path: ~/.local # the path depends on the OS
|
||||
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
|
||||
|
||||
- name: Configure Poetry
|
||||
run: poetry config virtualenvs.in-project true
|
||||
|
||||
#----------------------------------------------
|
||||
# install dependencies
|
||||
#----------------------------------------------
|
||||
|
@ -67,9 +79,21 @@ jobs:
|
|||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Info
|
||||
run: |
|
||||
sudo du -sh /tmp
|
||||
sudo df -h
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
|
||||
run: poetry install --no-interaction --no-root
|
||||
run: |
|
||||
mkdir ~/tmp
|
||||
echo $TMPDIR
|
||||
echo $TEMP
|
||||
echo $TMP
|
||||
poetry install --no-interaction --no-root
|
||||
|
||||
- name: Save cached venv
|
||||
if: |
|
||||
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
|
||||
|
@ -79,11 +103,16 @@ jobs:
|
|||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl)
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
#----------------------------------------------
|
||||
# install project
|
||||
#----------------------------------------------
|
||||
- name: Install project
|
||||
run: poetry install --no-interaction
|
||||
|
||||
#----------------------------------------------
|
||||
# run tests
|
||||
#----------------------------------------------
|
||||
|
@ -91,6 +120,7 @@ jobs:
|
|||
run: |
|
||||
source .venv/bin/activate
|
||||
pytest tests
|
||||
|
||||
- name: Test train pusht end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
|
@ -104,6 +134,7 @@ jobs:
|
|||
save_model=true \
|
||||
save_freq=1 \
|
||||
hydra.run.dir=tests/outputs/
|
||||
|
||||
- name: Test eval pusht end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
|
|
17
README.md
17
README.md
|
@ -59,19 +59,10 @@ env=pusht
|
|||
|
||||
## TODO
|
||||
|
||||
- [x] priority update doesn't match FOWM or original paper
|
||||
- [x] self.step=100000 should be updated at every step to adjust to the horizon of the planner
|
||||
- [ ] prefetch replay buffer to speedup training
|
||||
- [ ] parallelize env to speed up eval
|
||||
- [ ] clean checkpointing / loading
|
||||
- [ ] clean logging
|
||||
- [ ] clean config
|
||||
- [ ] clean hyperparameter tuning
|
||||
- [ ] add pusht
|
||||
- [ ] add aloha
|
||||
- [ ] add act
|
||||
- [ ] add diffusion
|
||||
- [ ] add aloha 2
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
|
||||
|
||||
Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
|
||||
|
||||
|
||||
## Profile
|
||||
|
||||
|
|
|
@ -81,7 +81,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
|||
def set_transform(self, transform):
|
||||
if not isinstance(transform, Compose):
|
||||
# required since torchrl calls `len(self._transform)` downstream
|
||||
self._transform = Compose(transform)
|
||||
if isinstance(transform, list):
|
||||
self._transform = Compose(*transform)
|
||||
else:
|
||||
self._transform = Compose(transform)
|
||||
else:
|
||||
self._transform = transform
|
||||
|
||||
|
|
|
@ -73,11 +73,11 @@ def download(data_dir, dataset_id):
|
|||
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
|
||||
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
|
||||
|
||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||
gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
|
||||
gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
|
||||
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
|
@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
|
|||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
|
||||
# def _is_downloaded(self) -> bool:
|
||||
# return False
|
||||
|
||||
def _download_and_preproc(self):
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
|
|
|
@ -5,7 +5,7 @@ from pathlib import Path
|
|||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.envs.transforms import NormalizeTransform
|
||||
from lerobot.common.envs.transforms import NormalizeTransform, Prod
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
|
@ -84,6 +84,16 @@ def make_offline_buffer(
|
|||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
img_keys = []
|
||||
for key in offline_buffer.image_keys:
|
||||
img_keys.append(("next", *key))
|
||||
img_keys += offline_buffer.image_keys
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
|
@ -92,11 +102,10 @@ def make_offline_buffer(
|
|||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
for key in offline_buffer.image_keys:
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(key)
|
||||
# since we use next observations in tdmpc
|
||||
in_keys.append(("next", *key))
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
|
@ -106,8 +115,11 @@ def make_offline_buffer(
|
|||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
offline_buffer.set_transform(transform)
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
|
||||
offline_buffer.set_transform(transforms)
|
||||
|
||||
if not overwrite_sampler:
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
|
@ -15,6 +16,22 @@ from torchrl.data.replay_buffers.writers import Writer
|
|||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
|
||||
|
||||
def download():
|
||||
raise NotImplementedError()
|
||||
import gdown
|
||||
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
Path(download_path).unlink()
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
|
@ -48,8 +65,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
|
|||
)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
# TODO(rcadene)
|
||||
# TODO(rcadene): finish download
|
||||
download()
|
||||
|
||||
dataset_path = self.data_dir / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
import abc
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
from tensordict import TensorDict
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
self._rendering_hooks = []
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
self._make_env()
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def register_rendering_hook(self, func):
|
||||
self._rendering_hooks.append(func)
|
||||
|
||||
def call_rendering_hooks(self):
|
||||
for func in self._rendering_hooks:
|
||||
func(self)
|
||||
|
||||
def reset_rendering_hooks(self):
|
||||
self._rendering_hooks = []
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _step(self, tensordict: TensorDict):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_env(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_spec(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
raise NotImplementedError()
|
|
@ -0,0 +1,59 @@
|
|||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
|
@ -0,0 +1,48 @@
|
|||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
|
@ -0,0 +1,53 @@
|
|||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
|
@ -0,0 +1,42 @@
|
|||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
|
@ -0,0 +1,38 @@
|
|||
<mujocoinclude>
|
||||
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
|
||||
|
||||
<asset>
|
||||
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
|
||||
</asset>
|
||||
|
||||
<visual>
|
||||
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
|
||||
<quality shadowsize="4096" offsamples="4"/>
|
||||
<headlight ambient="0.4 0.4 0.4"/>
|
||||
</visual>
|
||||
|
||||
<worldbody>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
|
||||
dir='1 1 -1'/>
|
||||
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
|
||||
dir='0 -1 -1'/>
|
||||
|
||||
<body name="table" pos="0 .6 0">
|
||||
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
|
||||
</body>
|
||||
<body name="midair" pos="0 .6 0.2">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
|
||||
</body>
|
||||
|
||||
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
|
||||
</worldbody>
|
||||
|
||||
|
||||
|
||||
</mujocoinclude>
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,17 @@
|
|||
<mujocoinclude>
|
||||
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
|
||||
<asset>
|
||||
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
|
||||
</asset>
|
||||
|
||||
</mujocoinclude>
|
|
@ -0,0 +1,59 @@
|
|||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_left" pos="-0.469 0.5 0">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
|
||||
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
|
||||
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
|
||||
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
|
||||
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
|
||||
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
|
||||
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
|
||||
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
|
@ -0,0 +1,59 @@
|
|||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
|
||||
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
|
||||
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
|
||||
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
|
||||
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
|
||||
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
|
||||
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
|
||||
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
|
@ -0,0 +1,163 @@
|
|||
from pathlib import Path
|
||||
|
||||
### Simulation envs fixed constants
|
||||
DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
|
||||
FPS = 50
|
||||
|
||||
|
||||
JOINTS = [
|
||||
# absolute joint position
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"left_arm_gripper",
|
||||
# absolute joint position
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
ACTIONS = [
|
||||
# position and quaternion for end effector
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"left_arm_gripper",
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
|
||||
START_ARM_POSE = [
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
]
|
||||
|
||||
ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
|
||||
def normalize_master_gripper_position(x):
|
||||
return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_position(x):
|
||||
return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_position(x):
|
||||
return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_position(x):
|
||||
return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def convert_position_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
|
||||
|
||||
|
||||
def normalizer_master_gripper_joint(x):
|
||||
return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_joint(x):
|
||||
return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_joint(x):
|
||||
return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_joint(x):
|
||||
return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def convert_join_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
|
||||
|
||||
|
||||
def normalize_master_gripper_velocity(x):
|
||||
return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_velocity(x):
|
||||
return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def convert_master_from_position_to_joint(x):
|
||||
return (
|
||||
normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_master_from_joint_to_position(x):
|
||||
return unnormalize_master_gripper_position(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_position_to_join(x):
|
||||
return (
|
||||
normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_joint_to_position(x):
|
||||
return unnormalize_puppet_gripper_position(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
|
@ -0,0 +1,311 @@
|
|||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
ACTIONS,
|
||||
ASSETS_DIR,
|
||||
DT,
|
||||
JOINTS,
|
||||
)
|
||||
from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
|
||||
from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
||||
InsertionEndEffectorTask,
|
||||
TransferCubeEndEffectorTask,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
if not self.from_pixels:
|
||||
raise NotImplementedError()
|
||||
|
||||
self._env = self._make_env_task(self.task)
|
||||
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
||||
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
||||
return image
|
||||
|
||||
def _make_env_task(self, task_name):
|
||||
# time limit is controlled by StepCounter in env factory
|
||||
time_limit = float("inf")
|
||||
|
||||
if "sim_transfer_cube" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeTask(random=False)
|
||||
elif "sim_insertion" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionTask(random=False)
|
||||
elif "sim_end_effector_transfer_cube" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeEndEffectorTask(random=False)
|
||||
elif "sim_end_effector_insertion" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionEndEffectorTask(random=False)
|
||||
else:
|
||||
raise NotImplementedError(task_name)
|
||||
|
||||
env = control.Environment(
|
||||
physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
|
||||
)
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = torch.from_numpy(raw_obs["images"]["top"].copy())
|
||||
image = einops.rearrange(image, "h w c -> c h w")
|
||||
assert image.dtype == torch.uint8
|
||||
obs = {"image": {"top": image}}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO(rcadene):
|
||||
raise NotImplementedError()
|
||||
# obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
_, reward, discount, raw_obs = self._env.step(action[i])
|
||||
del discount # not used
|
||||
|
||||
# TOOD(rcadene): add an enum
|
||||
success = done = reward == 4
|
||||
sum_reward += reward
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"]["top"])
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([success], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if self.from_pixels:
|
||||
if isinstance(self.image_size, int):
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
elif OmegaConf.is_list(self.image_size):
|
||||
assert len(self.image_size) == 3 # c h w
|
||||
assert self.image_size[0] == 3 # c is RGB
|
||||
image_shape = tuple(self.image_size)
|
||||
else:
|
||||
raise ValueError(self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = {
|
||||
"top": BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=image_shape,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
}
|
||||
if not self.pixels_only:
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
# TODO(rcadene): valid when controling end effector?
|
||||
# action_space = self._env.action_spec()
|
||||
# self.action_spec = BoundedTensorSpec(
|
||||
# low=action_space.minimum,
|
||||
# high=action_space.maximum,
|
||||
# shape=action_space.shape,
|
||||
# dtype=torch.float32,
|
||||
# device=self.device,
|
||||
# )
|
||||
|
||||
# TODO(rcaene): add bounds (where are they????)
|
||||
self.action_spec = BoundedTensorSpec(
|
||||
shape=(len(ACTIONS)),
|
||||
low=-1,
|
||||
high=1,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
# TODO(rcadene): seed the env
|
||||
# self._env.seed(seed)
|
||||
logging.warning("Aloha env is not seeded")
|
|
@ -0,0 +1,219 @@
|
|||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7 : 7 + 6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7 + 6]
|
||||
|
||||
left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
|
||||
right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate(
|
||||
[left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
|
||||
)
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
|
@ -0,0 +1,263 @@
|
|||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_pose (7), # position and quaternion for end effector
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXEndEffectorTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
a_len = len(action) // 2
|
||||
action_left = action[:a_len]
|
||||
action_right = action[a_len:]
|
||||
|
||||
# set mocap position and quat
|
||||
# left
|
||||
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||
|
||||
# set gripper
|
||||
g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
|
||||
g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
|
||||
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||
|
||||
def initialize_robots(self, physics):
|
||||
# reset joint position
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
|
||||
# reset mocap to align with end effector
|
||||
# to obtain these numbers:
|
||||
# (1) make an ee_sim env and reset to the same start_pose
|
||||
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||
# repeat the same for right side
|
||||
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||
|
||||
# reset gripper control
|
||||
close_gripper_control = np.array(
|
||||
[
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
]
|
||||
)
|
||||
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
# note: it is important to do .copy()
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
# used in scripted policy to obtain starting pose
|
||||
obs["mocap_pose_left"] = np.concatenate(
|
||||
[physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
|
||||
).copy()
|
||||
obs["mocap_pose_right"] = np.concatenate(
|
||||
[physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
|
||||
).copy()
|
||||
|
||||
# used when replaying joint trajectory
|
||||
obs["gripper_ctrl"] = physics.data.ctrl.copy()
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize box position
|
||||
cube_pose = sample_box_pose()
|
||||
box_start_idx = physics.model.name2id("red_box_joint", "joint")
|
||||
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize peg and socket position
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
|
||||
def id2index(j_id):
|
||||
return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||
|
||||
peg_start_id = physics.model.name2id("red_peg_joint", "joint")
|
||||
peg_start_idx = id2index(peg_start_id)
|
||||
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
|
||||
socket_start_idx = id2index(socket_start_id)
|
||||
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
|
@ -0,0 +1,39 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def sample_box_pose():
|
||||
x_range = [0.0, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
cube_quat = np.array([1, 0, 0, 0])
|
||||
return np.concatenate([cube_position, cube_quat])
|
||||
|
||||
|
||||
def sample_insertion_pose():
|
||||
# Peg
|
||||
x_range = [0.1, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
peg_quat = np.array([1, 0, 0, 0])
|
||||
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||
|
||||
# Socket
|
||||
x_range = [-0.2, -0.1]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
socket_quat = np.array([1, 0, 0, 0])
|
||||
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||
|
||||
return peg_pose, socket_pose
|
|
@ -23,6 +23,11 @@ def make_env(cfg, transform=None):
|
|||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||
|
||||
clsfunc = PushtEnv
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = AlohaEnv
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
|
|
|
@ -11,61 +11,52 @@ from torchrl.data.tensor_specs import (
|
|||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
|
||||
|
||||
class PushtEnv(EnvBase):
|
||||
class PushtEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task="pusht",
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=0,
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||
|
||||
if not from_pixels:
|
||||
if not self.from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
|
@ -122,6 +113,8 @@ class PushtEnv(EnvBase):
|
|||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
|
@ -154,6 +147,8 @@ class PushtEnv(EnvBase):
|
|||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
|
@ -175,9 +170,9 @@ class PushtEnv(EnvBase):
|
|||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=1,
|
||||
high=255,
|
||||
shape=image_shape,
|
||||
dtype=torch.float32,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
|
|
|
@ -25,7 +25,7 @@ class PushTImageEnv(PushTEnv):
|
|||
img = super()._render_frame(mode="rgb_array")
|
||||
|
||||
agent_pos = np.array(self.agent.position)
|
||||
img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0)
|
||||
img_obs = np.moveaxis(img, -1, 0)
|
||||
obs = {"image": img_obs, "agent_pos": agent_pos}
|
||||
|
||||
# draw action
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import importlib
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
@ -10,9 +12,9 @@ from torchrl.data.tensor_specs import (
|
|||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
|
@ -21,7 +23,7 @@ _has_gym = importlib.util.find_spec("gym") is not None
|
|||
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
|
||||
|
||||
|
||||
class SimxarmEnv(EnvBase):
|
||||
class SimxarmEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
@ -31,19 +33,22 @@ class SimxarmEnv(EnvBase):
|
|||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=0,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_simxarm:
|
||||
raise ImportError("Cannot import simxarm.")
|
||||
if not _has_gym:
|
||||
|
@ -63,9 +68,6 @@ class SimxarmEnv(EnvBase):
|
|||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
self._make_spec()
|
||||
self.set_seed(seed)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
|
||||
|
@ -90,15 +92,33 @@ class SimxarmEnv(EnvBase):
|
|||
if td is None or td.is_empty():
|
||||
raw_obs = self._env.reset()
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
|
@ -108,10 +128,32 @@ class SimxarmEnv(EnvBase):
|
|||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for _ in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
raw_obs, reward, done, info = self._env.step(action[i])
|
||||
sum_reward += reward
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
|
@ -126,23 +168,36 @@ class SimxarmEnv(EnvBase):
|
|||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
shape=image_shape,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = (len(self._env.robot_state),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=(len(self._env.robot_state),),
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = self._env.observation_space["observation"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
# TODO:
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
|
@ -7,19 +8,45 @@ from torchrl.envs.transforms import ObservationTransform, Transform
|
|||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
invertible = True
|
||||
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
self.original_dtypes = {}
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td):
|
||||
for key in self.in_keys:
|
||||
td[key] *= self.prod
|
||||
if td.get(key, None) is None:
|
||||
continue
|
||||
self.original_dtypes[key] = td[key].dtype
|
||||
td[key] = td[key].type(torch.float32) * self.prod
|
||||
return td
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
continue
|
||||
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
|
||||
return td
|
||||
|
||||
def transform_observation_spec(self, obs_spec):
|
||||
for key in self.in_keys:
|
||||
obs_spec[key].space.high *= self.prod
|
||||
if obs_spec.get(key, None) is None:
|
||||
continue
|
||||
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
obs_spec[key].dtype = torch.float32
|
||||
return obs_spec
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
from .utils import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super().__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
def __init__(
|
||||
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
|
||||
):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {"layer4": "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
|
||||
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(),
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for _, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = args.lr_backbone > 0
|
||||
return_interm_layers = args.masks
|
||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
|
@ -0,0 +1,212 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
from .backbone import build_backbone
|
||||
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
"""This is the DETR module that performs object detection"""
|
||||
|
||||
def __init__(
|
||||
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.vae = vae
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
# TODO(rcadene): understand what is env_state, and why it needs to be 7
|
||||
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
||||
self.latent_proj = nn.Linear(
|
||||
hidden_dim, self.latent_dim * 2
|
||||
) # project hidden state to latent std, var
|
||||
self.register_buffer(
|
||||
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||
) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(
|
||||
2, hidden_dim
|
||||
) # learned position embedding for proprio and latent
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
### Obtain latent z from action sequence
|
||||
if self.vae and is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat(
|
||||
[cls_embed, qpos_embed, action_embed], axis=1
|
||||
) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
mu = latent_info[:, : self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim :]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
else:
|
||||
mu = logvar = None
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
for cam_id, _ in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar]
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for _ in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
|
||||
|
||||
|
||||
def build_encoder(args):
|
||||
d_model = args.hidden_dim # 256
|
||||
dropout = args.dropout # 0.1
|
||||
nhead = args.nheads # 8
|
||||
dim_feedforward = args.dim_feedforward # 2048
|
||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||
normalize_before = args.pre_norm # False
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build(args):
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(args)
|
||||
|
||||
encoder = build_encoder(args)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=args.state_dim,
|
||||
action_dim=args.action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
vae=args.vae,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
||||
|
||||
return model
|
|
@ -0,0 +1,217 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
|
||||
|
||||
def build_act_model_and_optimizer(cfg):
|
||||
model = build(cfg)
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = device
|
||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.kl_weight = self.cfg.kl_weight
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
self.to(self.device)
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon)
|
||||
|
||||
image = batch["observation", "image", "top"]
|
||||
image = image[:, 0] # first observation t=0
|
||||
# batch, num_cam, channel, height, width
|
||||
image = image.unsqueeze(1)
|
||||
assert image.ndim == 5
|
||||
image = image.float()
|
||||
|
||||
state = batch["observation", "state"]
|
||||
state = state[:, 0] # first observation t=0
|
||||
# batch, qpos_dim
|
||||
assert state.ndim == 2
|
||||
|
||||
action = batch["action"]
|
||||
# batch, seq, action_dim
|
||||
assert action.ndim == 3
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if self.cfg.n_obs_steps > 1:
|
||||
raise NotImplementedError()
|
||||
# # keep first n observations of the slice corresponding to t=[-1,0]
|
||||
# image = image[:, : self.cfg.n_obs_steps]
|
||||
# state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
# self.lr_scheduler.step()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
# "lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
loss_dict = self._forward(
|
||||
qpos=batch["obs"]["agent_pos"],
|
||||
image=batch["obs"]["image"],
|
||||
actions=batch["action"],
|
||||
)
|
||||
loss = loss_dict["loss"]
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
self.eval()
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
|
||||
# observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
# TODO(rcadene): remove hack
|
||||
# add 1 camera dimension
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image", "top"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
raise NotImplementedError()
|
||||
# all_time_actions[[t], t:t+num_queries] = action
|
||||
# actions_for_curr_step = all_time_actions[:, t]
|
||||
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||
# k = 0.01
|
||||
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||
# exp_weights = exp_weights / exp_weights.sum()
|
||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
|
||||
# remove bsize=1
|
||||
action = action.squeeze(0)
|
||||
|
||||
# take first predicted action or n first actions
|
||||
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
|
||||
return action
|
||||
|
||||
def _forward(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
||||
is_training = actions is not None
|
||||
if is_training: # training time
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l1"] = l1
|
||||
if self.cfg.vae:
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return action
|
|
@ -0,0 +1,101 @@
|
|||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .utils import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = (
|
||||
torch.cat(
|
||||
[
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(0)
|
||||
.repeat(x.shape[0], 1, 1, 1)
|
||||
)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
n_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ("v2", "sine"):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
|
||||
elif args.position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(n_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
|
@ -0,0 +1,370 @@
|
|||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_encoder_layers=6,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(
|
||||
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
proprio_input=None,
|
||||
additional_pos_embed=None,
|
||||
):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, n):
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
|
@ -0,0 +1,477 @@
|
|||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
from packaging import version
|
||||
from torch import Tensor
|
||||
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue:
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||
)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list, strict=False):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger:
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append("{}: {}".format(name, str(meter)))
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
"max mem: {memory:.0f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
]
|
||||
)
|
||||
mega_b = 1024.0 * 1024.0
|
||||
for i, obj in enumerate(iterable):
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / mega_b,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
||||
|
||||
sha = "N/A"
|
||||
diff = "clean"
|
||||
branch = "N/A"
|
||||
try:
|
||||
sha = _run(["git", "rev-parse", "HEAD"])
|
||||
subprocess.check_output(["git", "diff"], cwd=cwd)
|
||||
diff = _run(["git", "diff-index", "HEAD"])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch, strict=False))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("not supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
|
||||
torch.int64
|
||||
)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
elif "SLURM_PROCID" in os.environ:
|
||||
args.rank = int(os.environ["SLURM_PROCID"])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
|
@ -17,6 +17,12 @@ def make_policy(cfg):
|
|||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(
|
||||
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
|
|
|
@ -15,11 +15,11 @@ env:
|
|||
task: sim_insertion_human
|
||||
from_pixels: True
|
||||
pixels_only: False
|
||||
image_size: 96
|
||||
image_size: [3, 480, 640]
|
||||
action_repeat: 1
|
||||
episode_length: 300
|
||||
episode_length: 400
|
||||
fps: ${fps}
|
||||
|
||||
policy:
|
||||
state_dim: 2
|
||||
action_dim: 2
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# @package _global_
|
||||
|
||||
offline_steps: 1344000
|
||||
online_steps: 0
|
||||
|
||||
eval_episodes: 1
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
|
||||
horizon: 100
|
||||
n_obs_steps: 1
|
||||
n_latency_steps: 0
|
||||
# when temporal_agg=False, n_action_steps=horizon
|
||||
n_action_steps: ${horizon}
|
||||
|
||||
policy:
|
||||
name: act
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
backbone: resnet18
|
||||
num_queries: ${horizon} # chunk_size
|
||||
horizon: ${horizon} # chunk_size
|
||||
kl_weight: 10
|
||||
hidden_dim: 512
|
||||
dim_feedforward: 3200
|
||||
enc_layers: 4
|
||||
dec_layers: 7
|
||||
nheads: 8
|
||||
#camera_names: [top, front_close, left_pillar, right_pillar]
|
||||
camera_names: [top]
|
||||
position_embedding: sine
|
||||
masks: false
|
||||
dilation: false
|
||||
dropout: 0.1
|
||||
pre_norm: false
|
||||
|
||||
vae: true
|
||||
|
||||
batch_size: 8
|
||||
|
||||
per_alpha: 0.6
|
||||
per_beta: 0.4
|
||||
|
||||
balanced_sampling: false
|
||||
utd: 1
|
||||
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
|
||||
temporal_agg: false
|
||||
|
||||
state_dim: ???
|
||||
action_dim: ???
|
|
@ -1,22 +0,0 @@
|
|||
# TODO(rcadene): obsolete remove
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
|
||||
def download():
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
os.remove(download_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download()
|
|
@ -38,27 +38,18 @@ def eval_policy(
|
|||
successes = []
|
||||
threads = []
|
||||
for i in tqdm.tqdm(range(num_episodes)):
|
||||
tensordict = env.reset()
|
||||
|
||||
ep_frames = []
|
||||
|
||||
if save_video or (return_first_video and i == 0):
|
||||
|
||||
def rendering_callback(env, td=None):
|
||||
def render_frame(env):
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
|
||||
# render first frame before rollout
|
||||
rendering_callback(env)
|
||||
else:
|
||||
rendering_callback = None
|
||||
env.register_rendering_hook(render_frame)
|
||||
|
||||
with torch.inference_mode():
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
callback=rendering_callback,
|
||||
auto_reset=False,
|
||||
tensordict=tensordict,
|
||||
auto_cast_to_device=True,
|
||||
)
|
||||
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
||||
|
@ -85,6 +76,8 @@ def eval_policy(
|
|||
if return_first_video and i == 0:
|
||||
first_video = stacked_frames.transpose(0, 3, 1, 2)
|
||||
|
||||
env.reset_rendering_hooks()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
|
@ -192,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||
num_episodes=cfg.eval_episodes,
|
||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
||||
return_first_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
save_video=True,
|
||||
)
|
||||
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
|
|
|
@ -340,69 +340,69 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "cython"
|
||||
version = "3.0.8"
|
||||
version = "3.0.9"
|
||||
description = "The Cython compiler for writing C extensions in the Python language."
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
files = [
|
||||
{file = "Cython-3.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a846e0a38e2b24e9a5c5dc74b0e54c6e29420d88d1dafabc99e0fc0f3e338636"},
|
||||
{file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45523fdc2b78d79b32834cc1cc12dc2ca8967af87e22a3ee1bff20e77c7f5520"},
|
||||
{file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa0b7f3f841fe087410cab66778e2d3fb20ae2d2078a2be3dffe66c6574be39"},
|
||||
{file = "Cython-3.0.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e87294e33e40c289c77a135f491cd721bd089f193f956f7b8ed5aa2d0b8c558f"},
|
||||
{file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a1df7a129344b1215c20096d33c00193437df1a8fcca25b71f17c23b1a44f782"},
|
||||
{file = "Cython-3.0.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:13c2a5e57a0358da467d97667297bf820b62a1a87ae47c5f87938b9bb593acbd"},
|
||||
{file = "Cython-3.0.8-cp310-cp310-win32.whl", hash = "sha256:96b028f044f5880e3cb18ecdcfc6c8d3ce9d0af28418d5ab464509f26d8adf12"},
|
||||
{file = "Cython-3.0.8-cp310-cp310-win_amd64.whl", hash = "sha256:8140597a8b5cc4f119a1190f5a2228a84f5ca6d8d9ec386cfce24663f48b2539"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aae26f9663e50caf9657148403d9874eea41770ecdd6caf381d177c2b1bb82ba"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:547eb3cdb2f8c6f48e6865d5a741d9dd051c25b3ce076fbca571727977b28ac3"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a567d4b9ba70b26db89d75b243529de9e649a2f56384287533cf91512705bee"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51d1426263b0e82fb22bda8ea60dc77a428581cc19e97741011b938445d383f1"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c26daaeccda072459b48d211415fd1e5507c06bcd976fa0d5b8b9f1063467d7b"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:289ce7838208211cd166e975865fd73b0649bf118170b6cebaedfbdaf4a37795"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-win32.whl", hash = "sha256:c8aa05f5e17f8042a3be052c24f2edc013fb8af874b0bf76907d16c51b4e7871"},
|
||||
{file = "Cython-3.0.8-cp311-cp311-win_amd64.whl", hash = "sha256:000dc9e135d0eec6ecb2b40a5b02d0868a2f8d2e027a41b0fe16a908a9e6de02"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90d3fe31db55685d8cb97d43b0ec39ef614fcf660f83c77ed06aa670cb0e164f"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e24791ddae2324e88e3c902a765595c738f19ae34ee66bfb1a6dac54b1833419"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f020fa1c0552052e0660790b8153b79e3fc9a15dbd8f1d0b841fe5d204a6ae6"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18bfa387d7a7f77d7b2526af69a65dbd0b731b8d941aaff5becff8e21f6d7717"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fe81b339cffd87c0069c6049b4d33e28bdd1874625ee515785bf42c9fdff3658"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:80fd94c076e1e1b1ee40a309be03080b75f413e8997cddcf401a118879863388"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-win32.whl", hash = "sha256:85077915a93e359a9b920280d214dc0cf8a62773e1f3d7d30fab8ea4daed670c"},
|
||||
{file = "Cython-3.0.8-cp312-cp312-win_amd64.whl", hash = "sha256:0cb2dcc565c7851f75d496f724a384a790fab12d1b82461b663e66605bec429a"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:870d2a0a7e3cbd5efa65aecdb38d715ea337a904ea7bb22324036e78fb7068e7"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e8f2454128974905258d86534f4fd4f91d2f1343605657ecab779d80c9d6d5e"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1949d6aa7bc792554bee2b67a9fe41008acbfe22f4f8df7b6ec7b799613a4b3"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9f2c6e1b8f3bcd6cb230bac1843f85114780bb8be8614855b1628b36bb510e0"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:05d7eddc668ae7993643f32c7661f25544e791edb745758672ea5b1a82ecffa6"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bfabe115deef4ada5d23c87bddb11289123336dcc14347011832c07db616dd93"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-win32.whl", hash = "sha256:0c38c9f0bcce2df0c3347285863621be904ac6b64c5792d871130569d893efd7"},
|
||||
{file = "Cython-3.0.8-cp36-cp36m-win_amd64.whl", hash = "sha256:6c46939c3983217d140999de7c238c3141f56b1ea349e47ca49cae899969aa2c"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:115f0a50f752da6c99941b103b5cb090da63eb206abbc7c2ad33856ffc73f064"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9c0f29246734561c90f36e70ed0506b61aa3d044e4cc4cba559065a2a741fae"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab75242869ff71e5665fe5c96f3378e79e792fa3c11762641b6c5afbbbbe026"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6717c06e9cfc6c1df18543cd31a21f5d8e378a40f70c851fa2d34f0597037abc"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9d3f74388db378a3c6fd06e79a809ed98df3f56484d317b81ee762dbf3c263e0"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ae7ac561fd8253a9ae96311e91d12af5f701383564edc11d6338a7b60b285a6f"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-win32.whl", hash = "sha256:97b2a45845b993304f1799664fa88da676ee19442b15fdcaa31f9da7e1acc434"},
|
||||
{file = "Cython-3.0.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9e2be2b340fea46fb849d378f9b80d3c08ff2e81e2bfbcdb656e2e3cd8c6b2dc"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2cde23c555470db3f149ede78b518e8274853745289c956a0e06ad8d982e4db9"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7990ca127e1f1beedaf8fc8bf66541d066ef4723ad7d8d47a7cbf842e0f47580"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b983c8e6803f016146c26854d9150ddad5662960c804ea7f0c752c9266752f0"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a973268d7ca1a2bdf78575e459a94a78e1a0a9bb62a7db0c50041949a73b02ff"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:61a237bc9dd23c7faef0fcfce88c11c65d0c9bb73c74ccfa408b3a012073c20e"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:3a3d67f079598af49e90ff9655bf85bd358f093d727eb21ca2708f467c489cae"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-win32.whl", hash = "sha256:17a642bb01a693e34c914106566f59844b4461665066613913463a719e0dd15d"},
|
||||
{file = "Cython-3.0.8-cp38-cp38-win_amd64.whl", hash = "sha256:2cdfc32252f3b6dc7c94032ab744dcedb45286733443c294d8f909a4854e7f83"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa97893d99385386925d00074654aeae3a98867f298d1e12ceaf38a9054a9bae"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f05c0bf9d085c031df8f583f0d506aa3be1692023de18c45d0aaf78685bbb944"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de892422582f5758bd8de187e98ac829330ec1007bc42c661f687792999988a7"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:314f2355a1f1d06e3c431eaad4708cf10037b5e91e4b231d89c913989d0bdafd"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:78825a3774211e7d5089730f00cdf7f473042acc9ceb8b9eeebe13ed3a5541de"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:df8093deabc55f37028190cf5e575c26aad23fc673f34b85d5f45076bc37ce39"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-win32.whl", hash = "sha256:1aca1b97e0095b3a9a6c33eada3f661a4ed0d499067d121239b193e5ba3bb4f0"},
|
||||
{file = "Cython-3.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:16873d78be63bd38ffb759da7ab82814b36f56c769ee02b1d5859560e4c3ac3c"},
|
||||
{file = "Cython-3.0.8-py2.py3-none-any.whl", hash = "sha256:171b27051253d3f9108e9759e504ba59ff06e7f7ba944457f94deaf9c21bf0b6"},
|
||||
{file = "Cython-3.0.8.tar.gz", hash = "sha256:8333423d8fd5765e7cceea3a9985dd1e0a5dfeb2734629e1a2ed2d6233d39de6"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:296bd30d4445ac61b66c9d766567f6e81a6e262835d261e903c60c891a6729d3"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f496b52845cb45568a69d6359a2c335135233003e708ea02155c10ce3548aa89"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:858c3766b9aa3ab8a413392c72bbab1c144a9766b7c7bfdef64e2e414363fa0c"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0eb1e6ef036028a52525fd9a012a556f6dd4788a0e8755fe864ba0e70cde2ff"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c8191941073ea5896321de3c8c958fd66e5f304b0cd1f22c59edd0b86c4dd90d"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e32b016030bc72a8a22a1f21f470a2f57573761a4f00fbfe8347263f4fbdb9f1"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-win32.whl", hash = "sha256:d6f3ff1cd6123973fe03e0fb8ee936622f976c0c41138969975824d08886572b"},
|
||||
{file = "Cython-3.0.9-cp310-cp310-win_amd64.whl", hash = "sha256:56f3b643dbe14449248bbeb9a63fe3878a24256664bc8c8ef6efd45d102596d8"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:35e6665a20d6b8a152d72b7fd87dbb2af6bb6b18a235b71add68122d594dbd41"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92f4960c40ad027bd8c364c50db11104eadc59ffeb9e5b7f605ca2f05946e20"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38df37d0e732fbd9a2fef898788492e82b770c33d1e4ed12444bbc8a3b3f89c0"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad7fd88ebaeaf2e76fd729a8919fae80dab3d6ac0005e28494261d52ff347a8f"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1365d5f76bf4d19df3d19ce932584c9bb76e9fb096185168918ef9b36e06bfa4"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c232e7f279388ac9625c3e5a5a9f0078a9334959c5d6458052c65bbbba895e1e"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-win32.whl", hash = "sha256:357e2fad46a25030b0c0496487e01a9dc0fdd0c09df0897f554d8ba3c1bc4872"},
|
||||
{file = "Cython-3.0.9-cp311-cp311-win_amd64.whl", hash = "sha256:1315aee506506e8d69cf6631d8769e6b10131fdcc0eb66df2698f2a3ddaeeff2"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:157973807c2796addbed5fbc4d9c882ab34bbc60dc297ca729504901479d5df7"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00b105b5d050645dd59e6767bc0f18b48a4aa11c85f42ec7dd8181606f4059e3"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac5536d09bef240cae0416d5a703d298b74c7bbc397da803ac9d344e732d4369"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c44501d476d16aaa4cbc29c87f8c0f54fc20e69b650d59cbfa4863426fc70c"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc9c3b9f20d8e298618e5ccd32083ca386e785b08f9893fbec4c50b6b85be772"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a30d96938c633e3ec37000ac3796525da71254ef109e66bdfd78f29891af6454"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-win32.whl", hash = "sha256:757ca93bdd80702546df4d610d2494ef2e74249cac4d5ba9464589fb464bd8a3"},
|
||||
{file = "Cython-3.0.9-cp312-cp312-win_amd64.whl", hash = "sha256:1dc320a9905ab95414013f6de805efbff9e17bb5fb3b90bbac533f017bec8136"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4ae349960ebe0da0d33724eaa7f1eb866688fe5434cc67ce4dbc06d6a719fbfc"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63d2537bf688247f76ded6dee28ebd26274f019309aef1eb4f2f9c5c482fde2d"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f5a2dfc724bea1f710b649f02d802d80fc18320c8e6396684ba4a48412445a"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deaf4197d4b0bcd5714a497158ea96a2bd6d0f9636095437448f7e06453cc83d"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:000af6deb7412eb7ac0c635ff5e637fb8725dd0a7b88cc58dfc2b3de14e701c4"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:15c7f5c2d35bed9aa5f2a51eaac0df23ae72f2dbacf62fc672dd6bfaa75d2d6f"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-win32.whl", hash = "sha256:f49aa4970cd3bec66ac22e701def16dca2a49c59cceba519898dd7526e0be2c0"},
|
||||
{file = "Cython-3.0.9-cp36-cp36m-win_amd64.whl", hash = "sha256:4558814fa025b193058d42eeee498a53d6b04b2980d01339fc2444b23fd98e58"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:539cd1d74fd61f6cfc310fa6bbbad5adc144627f2b7486a07075d4e002fd6aad"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232926cd406ee02eabb732206f6e882c3aed9d58f0fea764013d9240405bcf"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33b6ac376538a7fc8c567b85d3c71504308a9318702ec0485dd66c059f3165cb"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cc92504b5d22ac66031ffb827bd3a967fc75a5f0f76ab48bce62df19be6fdfd"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:22b8fae756c5c0d8968691bed520876de452f216c28ec896a00739a12dba3bd9"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9cda0d92a09f3520f29bd91009f1194ba9600777c02c30c6d2d4ac65fb63e40d"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-win32.whl", hash = "sha256:ec612418490941ed16c50c8d3784c7bdc4c4b2a10c361259871790b02ec8c1db"},
|
||||
{file = "Cython-3.0.9-cp37-cp37m-win_amd64.whl", hash = "sha256:976c8d2bedc91ff6493fc973d38b2dc01020324039e2af0e049704a8e1b22936"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5055988b007c92256b6e9896441c3055556038c3497fcbf8c921a6c1fce90719"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9360606d964c2d0492a866464efcf9d0a92715644eede3f6a2aa696de54a137"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c6e809f060bed073dc7cba1648077fe3b68208863d517c8b39f3920eecf9dd"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95ed792c966f969cea7489c32ff90150b415c1f3567db8d5a9d489c7c1602dac"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8edd59d22950b400b03ca78d27dc694d2836a92ef0cac4f64cb4b2ff902f7e25"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4cf0ed273bf60e97922fcbbdd380c39693922a597760160b4b4355e6078ca188"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-win32.whl", hash = "sha256:5eb9bd4ae12ebb2bc79a193d95aacf090fbd8d7013e11ed5412711650cb34934"},
|
||||
{file = "Cython-3.0.9-cp38-cp38-win_amd64.whl", hash = "sha256:44457279da56e0f829bb1fc5a5dc0836e5d498dbcf9b2324f32f7cc9d2ec6569"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4b419a1adc2af43f4660e2f6eaf1e4fac2dbac59490771eb8ac3d6063f22356"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f836192140f033b2319a0128936367c295c2b32e23df05b03b672a6015757ea"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd198c1a7f8e9382904d622cc0efa3c184605881fd5262c64cbb7168c4c1ec5"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a274fe9ca5c53fafbcf5c8f262f8ad6896206a466f0eeb40aaf36a7951e957c0"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:158c38360bbc5063341b1e78d3737f1251050f89f58a3df0d10fb171c44262be"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bf30b045f7deda0014b042c1b41c1d272facc762ab657529e3b05505888e878"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-win32.whl", hash = "sha256:9a001fd95c140c94d934078544ff60a3c46aca2dc86e75a76e4121d3cd1f4b33"},
|
||||
{file = "Cython-3.0.9-cp39-cp39-win_amd64.whl", hash = "sha256:530c01c4aebba709c0ec9c7ecefe07177d0b9fd7ffee29450a118d92192ccbdf"},
|
||||
{file = "Cython-3.0.9-py2.py3-none-any.whl", hash = "sha256:bf96417714353c5454c2e3238fca9338599330cf51625cdc1ca698684465646f"},
|
||||
{file = "Cython-3.0.9.tar.gz", hash = "sha256:a2d354f059d1f055d34cfaa62c5b68bc78ac2ceab6407148d47fb508cf3ba4f3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -488,6 +488,37 @@ files = [
|
|||
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dm-control"
|
||||
version = "1.0.14"
|
||||
description = "Continuous control environments and MuJoCo Python bindings."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"},
|
||||
{file = "dm_control-1.0.14.tar.gz", hash = "sha256:def1ece747b6f175c581150826b50f1a6134086dab34f8f3fd2d088ea035cf3d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
absl-py = ">=0.7.0"
|
||||
dm-env = "*"
|
||||
dm-tree = "!=0.1.2"
|
||||
glfw = "*"
|
||||
labmaze = "*"
|
||||
lxml = "*"
|
||||
mujoco = ">=2.3.7"
|
||||
numpy = ">=1.9.0"
|
||||
protobuf = ">=3.19.4"
|
||||
pyopengl = ">=3.1.4"
|
||||
pyparsing = ">=3.0.0"
|
||||
requests = "*"
|
||||
scipy = "*"
|
||||
setuptools = "!=50.0.0"
|
||||
tqdm = "*"
|
||||
|
||||
[package.extras]
|
||||
hdf5 = ["h5py"]
|
||||
|
||||
[[package]]
|
||||
name = "dm-env"
|
||||
version = "1.6"
|
||||
|
@ -584,43 +615,6 @@ files = [
|
|||
{file = "einops-0.7.0.tar.gz", hash = "sha256:b2b04ad6081a3b227080c9bf5e3ace7160357ff03043cd66cc5b2319eb7031d1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "etils"
|
||||
version = "1.7.0"
|
||||
description = "Collection of common python utils"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
files = [
|
||||
{file = "etils-1.7.0-py3-none-any.whl", hash = "sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60"},
|
||||
{file = "etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
fsspec = {version = "*", optional = true, markers = "extra == \"epath\""}
|
||||
importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""}
|
||||
typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""}
|
||||
zipp = {version = "*", optional = true, markers = "extra == \"epath\""}
|
||||
|
||||
[package.extras]
|
||||
all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"]
|
||||
array-types = ["etils[enp]"]
|
||||
dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"]
|
||||
docs = ["etils[all,dev]", "sphinx-apitree[ext]"]
|
||||
eapp = ["absl-py", "etils[epy]", "simple_parsing"]
|
||||
ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"]
|
||||
edc = ["etils[epy]"]
|
||||
enp = ["etils[epy]", "numpy"]
|
||||
epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"]
|
||||
epath-gcs = ["etils[epath]", "gcsfs"]
|
||||
epath-s3 = ["etils[epath]", "s3fs"]
|
||||
epy = ["typing_extensions"]
|
||||
etqdm = ["absl-py", "etils[epy]", "tqdm"]
|
||||
etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"]
|
||||
etree-dm = ["dm-tree", "etils[etree]"]
|
||||
etree-jax = ["etils[etree]", "jax[cpu]"]
|
||||
etree-tf = ["etils[etree]", "tensorflow"]
|
||||
lazy-imports = ["etils[ecolab]"]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.2.0"
|
||||
|
@ -846,13 +840,13 @@ numpy = ">=1.17.3"
|
|||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.21.3"
|
||||
version = "0.21.4"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.21.3-py3-none-any.whl", hash = "sha256:b183144336fdf2810a8c109822e0bb6ef1fd61c65da6fb60e8c3f658b7144016"},
|
||||
{file = "huggingface_hub-0.21.3.tar.gz", hash = "sha256:26a15b604e4fc7bad37c467b76456543ec849386cbca9cd7e1e135f53e500423"},
|
||||
{file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
|
||||
{file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -971,37 +965,22 @@ setuptools = "*"
|
|||
|
||||
[[package]]
|
||||
name = "importlib-metadata"
|
||||
version = "7.0.1"
|
||||
version = "7.0.2"
|
||||
description = "Read metadata from Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "importlib_metadata-7.0.1-py3-none-any.whl", hash = "sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e"},
|
||||
{file = "importlib_metadata-7.0.1.tar.gz", hash = "sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc"},
|
||||
{file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"},
|
||||
{file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
zipp = ">=0.5"
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
perf = ["ipython"]
|
||||
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "importlib-resources"
|
||||
version = "6.1.2"
|
||||
description = "Read resources from Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "importlib_resources-6.1.2-py3-none-any.whl", hash = "sha256:9a0a862501dc38b68adebc82970140c9e4209fc99601782925178f8386339938"},
|
||||
{file = "importlib_resources-6.1.2.tar.gz", hash = "sha256:308abf8474e2dba5f867d279237cd4076482c3de7104a40b41426370e891549b"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
|
||||
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
|
@ -1031,6 +1010,50 @@ MarkupSafe = ">=2.0"
|
|||
[package.extras]
|
||||
i18n = ["Babel (>=2.7)"]
|
||||
|
||||
[[package]]
|
||||
name = "labmaze"
|
||||
version = "1.0.6"
|
||||
description = "LabMaze: DeepMind Lab's text maze generator."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "labmaze-1.0.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b2ddef976dfd8d992b19cfa6c633f2eba7576d759c2082da534e3f727479a84a"},
|
||||
{file = "labmaze-1.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:157efaa93228c8ccce5cae337902dd652093e0fba9d3a0f6506e4bee272bb66f"},
|
||||
{file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3ce98b9541c5fe6a306e411e7d018121dd646f2c9978d763fad86f9f30c5f57"},
|
||||
{file = "labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e6433bd49bc541791de8191040526fddfebb77151620eb04203453f43ee486a"},
|
||||
{file = "labmaze-1.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:6a507fc35961f1b1479708e2716f65e0d0611cefb55f31a77be29ce2339b6fef"},
|
||||
{file = "labmaze-1.0.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a0c2cb9dec971814ea9c5d7150af15fa3964482131fa969e0afb94bd224348af"},
|
||||
{file = "labmaze-1.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c6ba9538d819543f4be448d36b4926a3881e53646a2b331ebb5a1f353047d05"},
|
||||
{file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70635d1cdb0147a02efb6b3f607a52cdc51723bc3dcc42717a0d4ef55fa0a987"},
|
||||
{file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff472793238bd9b6dabea8094594d6074ad3c111455de3afcae72f6c40c6817e"},
|
||||
{file = "labmaze-1.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:2317e65e12fa3d1abecda7e0488dab15456cee8a2e717a586bfc8f02a91579e7"},
|
||||
{file = "labmaze-1.0.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e36b6fadcd78f22057b597c1c77823e806a0987b3bdfbf850e14b6b5b502075e"},
|
||||
{file = "labmaze-1.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d1a4f8de29c2c3d7f14163759b69cd3f237093b85334c983619c1db5403a223b"},
|
||||
{file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a394f8bb857fcaa2884b809d63e750841c2662a106cfe8c045f2112d201ac7d5"},
|
||||
{file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d17abb69d4dfc56183afb5c317e8b2eaca0587abb3aabd2326efd3143c81f4e"},
|
||||
{file = "labmaze-1.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:5af997598cc46b1929d1c5a1febc32fd56c75874fe481a2a5982c65cee8450c9"},
|
||||
{file = "labmaze-1.0.6-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:a4c5bc6e56baa55ce63b97569afec2f80cab0f6b952752a131e1f83eed190a53"},
|
||||
{file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3955f24fe5f708e1e97495b4cfe284b70ae4fd51be5e17b75a6fc04ffbd67bca"},
|
||||
{file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed96ddc0bb8d66df36428c94db83949fd84a15867e8250763a4c5e3d82104c54"},
|
||||
{file = "labmaze-1.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:3bd0458a29e55aa09f146e28a168d2e00b8ccf19e2259a3f71154cfff3536b1d"},
|
||||
{file = "labmaze-1.0.6-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:33f5154edc83dff55a150e54b60c8582fdafc7ec45195049809cbcc01f5e8f34"},
|
||||
{file = "labmaze-1.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0971055ef2a5f7d8517fdc42b67c057093698f1eb911f46faa7018867b73fcc9"},
|
||||
{file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de18d09680007302abf49111f3fe822d8435e4fbc4468b9ec07d50a78e267865"},
|
||||
{file = "labmaze-1.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f18126066db2218a52853c7dd490b4c3d8129fc22eb3a47eb23007524b911d53"},
|
||||
{file = "labmaze-1.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:f9aef09a76877342bb4d634b7e05f43b038a49c4f34adfb8f1b8ac57c29472f2"},
|
||||
{file = "labmaze-1.0.6-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5dd28899418f1b8b1c7d1e1b40a4593150a7cfa95ca91e23860b9785b82cc0ee"},
|
||||
{file = "labmaze-1.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:965569f37ee33090b4d4b3aa5aa7c9dcc4f62e2ae5d761e7f73ec76fc9d8aa96"},
|
||||
{file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05eccfa98c0e781bc9f939076ae600b2e25ca736e123f2a530606aedec3b531c"},
|
||||
{file = "labmaze-1.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee8c94e0fb3fc2d8180214947245c1d74a3489349a9da90b868296e77a521e9"},
|
||||
{file = "labmaze-1.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:d486e9ca3a335ad628e3bd48a09c42f1aa5f51040952ef0fe32507afedcd694b"},
|
||||
{file = "labmaze-1.0.6.tar.gz", hash = "sha256:2e8de7094042a77d6972f1965cf5c9e8f971f1b34d225752f343190a825ebe73"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
absl-py = "*"
|
||||
numpy = ">=1.8.0"
|
||||
setuptools = "!=50.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "lazy-loader"
|
||||
version = "0.3"
|
||||
|
@ -1076,6 +1099,99 @@ files = [
|
|||
{file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lxml"
|
||||
version = "5.1.0"
|
||||
description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:704f5572ff473a5f897745abebc6df40f22d4133c1e0a1f124e4f2bd3330ff7e"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9d3c0f8567ffe7502d969c2c1b809892dc793b5d0665f602aad19895f8d508da"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5fcfbebdb0c5d8d18b84118842f31965d59ee3e66996ac842e21f957eb76138c"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f37c6d7106a9d6f0708d4e164b707037b7380fcd0b04c5bd9cae1fb46a856fb"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2befa20a13f1a75c751f47e00929fb3433d67eb9923c2c0b364de449121f447c"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22b7ee4c35f374e2c20337a95502057964d7e35b996b1c667b5c65c567d2252a"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bf8443781533b8d37b295016a4b53c1494fa9a03573c09ca5104550c138d5c05"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:82bddf0e72cb2af3cbba7cec1d2fd11fda0de6be8f4492223d4a268713ef2147"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-win32.whl", hash = "sha256:b66aa6357b265670bb574f050ffceefb98549c721cf28351b748be1ef9577d93"},
|
||||
{file = "lxml-5.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:4946e7f59b7b6a9e27bef34422f645e9a368cb2be11bf1ef3cafc39a1f6ba68d"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:14deca1460b4b0f6b01f1ddc9557704e8b365f55c63070463f6c18619ebf964f"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed8c3d2cd329bf779b7ed38db176738f3f8be637bb395ce9629fc76f78afe3d4"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:436a943c2900bb98123b06437cdd30580a61340fbdb7b28aaf345a459c19046a"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:acb6b2f96f60f70e7f34efe0c3ea34ca63f19ca63ce90019c6cbca6b676e81fa"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8920ce4a55ff41167ddbc20077f5698c2e710ad3353d32a07d3264f3a2021e"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cfced4a069003d8913408e10ca8ed092c49a7f6cefee9bb74b6b3e860683b45"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9e5ac3437746189a9b4121db2a7b86056ac8786b12e88838696899328fc44bb2"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4c9bda132ad108b387c33fabfea47866af87f4ea6ffb79418004f0521e63204"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-win32.whl", hash = "sha256:bc64d1b1dab08f679fb89c368f4c05693f58a9faf744c4d390d7ed1d8223869b"},
|
||||
{file = "lxml-5.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5ab722ae5a873d8dcee1f5f45ddd93c34210aed44ff2dc643b5025981908cda"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9aa543980ab1fbf1720969af1d99095a548ea42e00361e727c58a40832439114"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6f11b77ec0979f7e4dc5ae081325a2946f1fe424148d3945f943ceaede98adb8"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a36c506e5f8aeb40680491d39ed94670487ce6614b9d27cabe45d94cd5d63e1e"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f643ffd2669ffd4b5a3e9b41c909b72b2a1d5e4915da90a77e119b8d48ce867a"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16dd953fb719f0ffc5bc067428fc9e88f599e15723a85618c45847c96f11f431"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16018f7099245157564d7148165132c70adb272fb5a17c048ba70d9cc542a1a1"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82cd34f1081ae4ea2ede3d52f71b7be313756e99b4b5f829f89b12da552d3aa3"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:19a1bc898ae9f06bccb7c3e1dfd73897ecbbd2c96afe9095a6026016e5ca97b8"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-win32.whl", hash = "sha256:13521a321a25c641b9ea127ef478b580b5ec82aa2e9fc076c86169d161798b01"},
|
||||
{file = "lxml-5.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ad17c20e3666c035db502c78b86e58ff6b5991906e55bdbef94977700c72623"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:24ef5a4631c0b6cceaf2dbca21687e29725b7c4e171f33a8f8ce23c12558ded1"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d2900b7f5318bc7ad8631d3d40190b95ef2aa8cc59473b73b294e4a55e9f30f"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:601f4a75797d7a770daed8b42b97cd1bb1ba18bd51a9382077a6a247a12aa38d"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4b68c961b5cc402cbd99cca5eb2547e46ce77260eb705f4d117fd9c3f932b95"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:afd825e30f8d1f521713a5669b63657bcfe5980a916c95855060048b88e1adb7"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:262bc5f512a66b527d026518507e78c2f9c2bd9eb5c8aeeb9f0eb43fcb69dc67"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-win32.whl", hash = "sha256:e856c1c7255c739434489ec9c8aa9cdf5179785d10ff20add308b5d673bed5cd"},
|
||||
{file = "lxml-5.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:c7257171bb8d4432fe9d6fdde4d55fdbe663a63636a17f7f9aaba9bcb3153ad7"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b9e240ae0ba96477682aa87899d94ddec1cc7926f9df29b1dd57b39e797d5ab5"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a96f02ba1bcd330807fc060ed91d1f7a20853da6dd449e5da4b09bfcc08fdcf5"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3898ae2b58eeafedfe99e542a17859017d72d7f6a63de0f04f99c2cb125936"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61c5a7edbd7c695e54fca029ceb351fc45cd8860119a0f83e48be44e1c464862"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3aeca824b38ca78d9ee2ab82bd9883083d0492d9d17df065ba3b94e88e4d7ee6"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8f52fe6859b9db71ee609b0c0a70fea5f1e71c3462ecf144ca800d3f434f0764"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-win32.whl", hash = "sha256:d42e3a3fc18acc88b838efded0e6ec3edf3e328a58c68fbd36a7263a874906c8"},
|
||||
{file = "lxml-5.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:eac68f96539b32fce2c9b47eb7c25bb2582bdaf1bbb360d25f564ee9e04c542b"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ae15347a88cf8af0949a9872b57a320d2605ae069bcdf047677318bc0bba45b1"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c26aab6ea9c54d3bed716b8851c8bfc40cb249b8e9880e250d1eddde9f709bf5"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:342e95bddec3a698ac24378d61996b3ee5ba9acfeb253986002ac53c9a5f6f84"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:725e171e0b99a66ec8605ac77fa12239dbe061482ac854d25720e2294652eeaa"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d184e0d5c918cff04cdde9dbdf9600e960161d773666958c9d7b565ccc60c45"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:98f3f020a2b736566c707c8e034945c02aa94e124c24f77ca097c446f81b01f1"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d48fc57e7c1e3df57be5ae8614bab6d4e7b60f65c5457915c26892c41afc59e"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-win32.whl", hash = "sha256:7ec465e6549ed97e9f1e5ed51c657c9ede767bc1c11552f7f4d022c4df4a977a"},
|
||||
{file = "lxml-5.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:b21b4031b53d25b0858d4e124f2f9131ffc1530431c6d1321805c90da78388d1"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:52427a7eadc98f9e62cb1368a5079ae826f94f05755d2d567d93ee1bc3ceb354"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6a2a2c724d97c1eb8cf966b16ca2915566a4904b9aad2ed9a09c748ffe14f969"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:843b9c835580d52828d8f69ea4302537337a21e6b4f1ec711a52241ba4a824f3"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b99f564659cfa704a2dd82d0684207b1aadf7d02d33e54845f9fc78e06b7581"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8b0c78e7aac24979ef09b7f50da871c2de2def043d468c4b41f512d831e912"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bcf86dfc8ff3e992fed847c077bd875d9e0ba2fa25d859c3a0f0f76f07f0c8d"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:49a9b4af45e8b925e1cd6f3b15bbba2c81e7dba6dce170c677c9cda547411e14"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:280f3edf15c2a967d923bcfb1f8f15337ad36f93525828b40a0f9d6c2ad24890"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-win32.whl", hash = "sha256:ed7326563024b6e91fef6b6c7a1a2ff0a71b97793ac33dbbcf38f6005e51ff6e"},
|
||||
{file = "lxml-5.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8d7b4beebb178e9183138f552238f7e6613162a42164233e2bda00cb3afac58f"},
|
||||
{file = "lxml-5.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9bd0ae7cc2b85320abd5e0abad5ccee5564ed5f0cc90245d2f9a8ef330a8deae"},
|
||||
{file = "lxml-5.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c1d679df4361408b628f42b26a5d62bd3e9ba7f0c0e7969f925021554755aa"},
|
||||
{file = "lxml-5.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2ad3a8ce9e8a767131061a22cd28fdffa3cd2dc193f399ff7b81777f3520e372"},
|
||||
{file = "lxml-5.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:304128394c9c22b6569eba2a6d98392b56fbdfbad58f83ea702530be80d0f9df"},
|
||||
{file = "lxml-5.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d74fcaf87132ffc0447b3c685a9f862ffb5b43e70ea6beec2fb8057d5d2a1fea"},
|
||||
{file = "lxml-5.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:8cf5877f7ed384dabfdcc37922c3191bf27e55b498fecece9fd5c2c7aaa34c33"},
|
||||
{file = "lxml-5.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:877efb968c3d7eb2dad540b6cabf2f1d3c0fbf4b2d309a3c141f79c7e0061324"},
|
||||
{file = "lxml-5.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f14a4fb1c1c402a22e6a341a24c1341b4a3def81b41cd354386dcb795f83897"},
|
||||
{file = "lxml-5.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:25663d6e99659544ee8fe1b89b1a8c0aaa5e34b103fab124b17fa958c4a324a6"},
|
||||
{file = "lxml-5.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8b9f19df998761babaa7f09e6bc169294eefafd6149aaa272081cbddc7ba4ca3"},
|
||||
{file = "lxml-5.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e53d7e6a98b64fe54775d23a7c669763451340c3d44ad5e3a3b48a1efbdc96f"},
|
||||
{file = "lxml-5.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c3cd1fc1dc7c376c54440aeaaa0dcc803d2126732ff5c6b68ccd619f2e64be4f"},
|
||||
{file = "lxml-5.1.0.tar.gz", hash = "sha256:3eea6ed6e6c918e468e693c41ef07f3c3acc310b70ddd9cc72d9ef84bc9564ca"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
cssselect = ["cssselect (>=0.7)"]
|
||||
html5 = ["html5lib"]
|
||||
htmlsoup = ["BeautifulSoup4"]
|
||||
source = ["Cython (>=3.0.7)"]
|
||||
|
||||
[[package]]
|
||||
name = "markupsafe"
|
||||
version = "2.1.5"
|
||||
|
@ -1188,42 +1304,40 @@ tests = ["pytest (>=4.6)"]
|
|||
|
||||
[[package]]
|
||||
name = "mujoco"
|
||||
version = "3.1.2"
|
||||
version = "2.3.7"
|
||||
description = "MuJoCo Physics Simulator"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "mujoco-3.1.2-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:fe6b3542695a5363f348ee45625b3492734f29cdc9f493ca25eae719f974370e"},
|
||||
{file = "mujoco-3.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f07e2d1f01f1401f1a503187016f8c017d9402618c659e1482243640a1e11288"},
|
||||
{file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93863eccc9d77d96ce62dda2a6f61cbd880379e8d774f802568d64b9613fce39"},
|
||||
{file = "mujoco-3.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3586c642390c16fef58b01a86071cec6814c471586e2f4115c3733c4aec64fb7"},
|
||||
{file = "mujoco-3.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0da77394c664945b78f199c627b609fe091ec0c4641b9d8f713637344a17821a"},
|
||||
{file = "mujoco-3.1.2-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:b6f12904d0478c191e4770ecf9006e20953f0488a2411a8ddc62592721c136dc"},
|
||||
{file = "mujoco-3.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f69b8d42b50c10f8d12df4948fc9d4dd6706841e7b163c1d7ce83448965acb1c"},
|
||||
{file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10119e39b1f45fb76b18bea242fea1d6ccf4b2285f8bd5e2cb1e2cbdeb69bdcd"},
|
||||
{file = "mujoco-3.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a65868506dd45dddfe7be84857e57b49bc102334fc0439aa848a4d4d285d89b"},
|
||||
{file = "mujoco-3.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:92bc73972e39539f23a05bb411c45f9be17191fe01171ac15ffafed381ee4366"},
|
||||
{file = "mujoco-3.1.2-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:835d6b64ca4dc2f6a83291275fd48bd83edc888039d247958bf5b2c759db4340"},
|
||||
{file = "mujoco-3.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ce94ca3cf14fc519981674c5b85f1055356dcdcd63bbc0ec6c340084438f27f"},
|
||||
{file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:250d9de4bd0d31fa4165faf01a1f838c429434f1263faacd95b977580f24eae7"},
|
||||
{file = "mujoco-3.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ea009d10bbf0aba9bc835f051d25f07a2c3edbaa06627ac2348766a1f3760b9"},
|
||||
{file = "mujoco-3.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:a0460d2ebdad4926f48b8c774da473e011c3b3afd0ccb6b6be1087b788c34011"},
|
||||
{file = "mujoco-3.1.2-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:4ca7cae89e258a338e02229edcf8f177b459ac5e9f859ffffa07fc2c9fcfb6aa"},
|
||||
{file = "mujoco-3.1.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:33b4fe9b5f891b29ef0fc2b0b975bc3a8a4b87774eecaf4364a83ddc6a7762ba"},
|
||||
{file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed230980f33bafaf1fa8b32ef25b82b069a245de15ee6ce7127e7e054cfad16"},
|
||||
{file = "mujoco-3.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41cc610ac40f325c9d49d9885ac6cb61822ed938f6c23cb183b261a7a28472ca"},
|
||||
{file = "mujoco-3.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:90a172b904a6ca8e6a1be80ab7c393aaff7592843a2a6853a4f97a9204031c41"},
|
||||
{file = "mujoco-3.1.2-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:93201291a0c5b573b4cbb19a6b08c99673f9fba167f402174eae5ffa23066d24"},
|
||||
{file = "mujoco-3.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0398985bb28c2686cdeeaf4ded46e602a49ec12115ac77474144ca940e5261c5"},
|
||||
{file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2e76b5cb07ab3088c81966ac774d573df027fa5f4e78c20953a547528a2a698"},
|
||||
{file = "mujoco-3.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd5c3f4ae858e812cb3f03332693bcdc343b2bce55b164523acf52dea2736c9e"},
|
||||
{file = "mujoco-3.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:ca25ff2646b06609526ef8681c0e123cd854a53c9ff23cb91dd5058a2794dab4"},
|
||||
{file = "mujoco-3.1.2.tar.gz", hash = "sha256:53530bc1a91903f3fd4b1e99818cc38fbd9911700db29b2c9fc839f23bfacbb8"},
|
||||
{file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"},
|
||||
{file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"},
|
||||
{file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"},
|
||||
{file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"},
|
||||
{file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"},
|
||||
{file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"},
|
||||
{file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"},
|
||||
{file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"},
|
||||
{file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"},
|
||||
{file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"},
|
||||
{file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"},
|
||||
{file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"},
|
||||
{file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"},
|
||||
{file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"},
|
||||
{file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"},
|
||||
{file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"},
|
||||
{file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"},
|
||||
{file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"},
|
||||
{file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"},
|
||||
{file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"},
|
||||
{file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"},
|
||||
{file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"},
|
||||
{file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"},
|
||||
{file = "mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"},
|
||||
{file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
absl-py = "*"
|
||||
etils = {version = "*", extras = ["epath"]}
|
||||
glfw = "*"
|
||||
numpy = "*"
|
||||
pyopengl = "*"
|
||||
|
@ -1519,13 +1633,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "nvidia-nvjitlink-cu12"
|
||||
version = "12.3.101"
|
||||
version = "12.4.99"
|
||||
description = "Nvidia JIT LTO Library"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1580,13 +1694,13 @@ numpy = [
|
|||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "23.2"
|
||||
version = "24.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
|
||||
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
|
||||
{file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
|
||||
{file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -2016,6 +2130,20 @@ files = [
|
|||
{file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyparsing"
|
||||
version = "3.1.2"
|
||||
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
|
||||
optional = false
|
||||
python-versions = ">=3.6.8"
|
||||
files = [
|
||||
{file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"},
|
||||
{file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
diagrams = ["jinja2", "railroad-diagrams"]
|
||||
|
||||
[[package]]
|
||||
name = "pysocks"
|
||||
version = "1.7.1"
|
||||
|
@ -2030,13 +2158,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.1.0"
|
||||
version = "8.1.1"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-8.1.0-py3-none-any.whl", hash = "sha256:ee32db7af8de4629a455806befa90559f307424c07b8413ccfc30bf5b221dd7e"},
|
||||
{file = "pytest-8.1.0.tar.gz", hash = "sha256:f8fa04ab8f98d185113ae60ea6d79c22f8143b14bc1caeced44a0ab844928323"},
|
||||
{file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"},
|
||||
{file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -2052,13 +2180,13 @@ testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygm
|
|||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
version = "2.9.0.post0"
|
||||
description = "Extensions to the standard Python datetime module"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
files = [
|
||||
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
||||
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
||||
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
|
||||
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -2483,13 +2611,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
|
|||
|
||||
[[package]]
|
||||
name = "sentry-sdk"
|
||||
version = "1.40.6"
|
||||
version = "1.41.0"
|
||||
description = "Python client for Sentry (https://sentry.io)"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "sentry-sdk-1.40.6.tar.gz", hash = "sha256:f143f3fb4bb57c90abef6e2ad06b5f6f02b2ca13e4060ec5c0549c7a9ccce3fa"},
|
||||
{file = "sentry_sdk-1.40.6-py2.py3-none-any.whl", hash = "sha256:becda09660df63e55f307570e9817c664392655a7328bbc414b507e9cb874c67"},
|
||||
{file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"},
|
||||
{file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -2515,7 +2643,7 @@ huey = ["huey (>=2)"]
|
|||
loguru = ["loguru (>=0.5)"]
|
||||
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
|
||||
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
|
||||
pure-eval = ["asttokens", "executing", "pure_eval"]
|
||||
pure-eval = ["asttokens", "executing", "pure-eval"]
|
||||
pymongo = ["pymongo (>=3.1)"]
|
||||
pyspark = ["pyspark (>=2.4.4)"]
|
||||
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
|
||||
|
@ -2769,7 +2897,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures
|
|||
type = "git"
|
||||
url = "https://github.com/pytorch/tensordict"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "551331d83e2979dd4505db1f49895740e6e5c95f"
|
||||
resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a"
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
|
@ -2888,13 +3016,13 @@ tensordict = ">=0.4.0"
|
|||
torch = ">=2.1.0"
|
||||
|
||||
[package.extras]
|
||||
all = ["ale-py", "atari-py", "dm-control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface-hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
|
||||
all = ["ale-py", "atari-py", "dm_control", "git", "gym", "gym[accept-rom-license]", "gymnasium", "h5py", "huggingface_hub", "hydra-core (>=1.1)", "hydra-submitit-launcher", "minari", "moviepy", "mujoco", "pandas", "pettingzoo (>=1.24.1)", "pillow", "pygame", "pytest", "pytest-instafail", "pyyaml", "requests", "scikit-learn", "scipy", "tensorboard", "torchsnapshot", "torchvision", "tqdm", "vmas (>=1.2.10)", "wandb"]
|
||||
atari = ["ale-py", "atari-py", "gym", "gym[accept-rom-license]", "pygame"]
|
||||
checkpointing = ["torchsnapshot"]
|
||||
dm-control = ["dm-control"]
|
||||
dm-control = ["dm_control"]
|
||||
gym-continuous = ["gymnasium", "mujoco"]
|
||||
marl = ["pettingzoo (>=1.24.1)", "vmas (>=1.2.10)"]
|
||||
offline-data = ["h5py", "huggingface-hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
|
||||
offline-data = ["h5py", "huggingface_hub", "minari", "pandas", "pillow", "requests", "scikit-learn", "torchvision", "tqdm"]
|
||||
rendering = ["moviepy"]
|
||||
tests = ["pytest", "pytest-instafail", "pyyaml", "scipy"]
|
||||
utils = ["git", "hydra-core (>=1.1)", "hydra-submitit-launcher", "tensorboard", "tqdm", "wandb"]
|
||||
|
@ -3051,13 +3179,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
|
|||
|
||||
[[package]]
|
||||
name = "wandb"
|
||||
version = "0.16.3"
|
||||
version = "0.16.4"
|
||||
description = "A CLI and library for interacting with the Weights & Biases API."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "wandb-0.16.3-py3-none-any.whl", hash = "sha256:b8907ddd775c27dc6c12687386a86b5d6acf291060f9ae680bbc61cc8fc03237"},
|
||||
{file = "wandb-0.16.3.tar.gz", hash = "sha256:d789acda32053b18b7a160d0595837e45a3c8a79d25e1fe1f051875303f480ec"},
|
||||
{file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"},
|
||||
{file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -3078,8 +3206,9 @@ async = ["httpx (>=0.23.0)"]
|
|||
aws = ["boto3"]
|
||||
azure = ["azure-identity", "azure-storage-blob"]
|
||||
gcp = ["google-cloud-storage"]
|
||||
importers = ["filelock", "mlflow", "polars", "rich", "tenacity"]
|
||||
kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
|
||||
launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "typing-extensions"]
|
||||
launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"]
|
||||
media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"]
|
||||
models = ["cloudpickle"]
|
||||
perf = ["orjson"]
|
||||
|
@ -3088,13 +3217,13 @@ sweeps = ["sweeps (>=0.2.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "zarr"
|
||||
version = "2.17.0"
|
||||
version = "2.17.1"
|
||||
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "zarr-2.17.0-py3-none-any.whl", hash = "sha256:d287cb61019c4a0a0f386f76eeaa7f0b1160b1cb90cf96173a4b6cbc135df6e1"},
|
||||
{file = "zarr-2.17.0.tar.gz", hash = "sha256:6390a2b8af31babaab4c963efc45bf1da7f9500c9aafac193f84cf019a7c66b0"},
|
||||
{file = "zarr-2.17.1-py3-none-any.whl", hash = "sha256:e25df2741a6e92645f3890f30f3136d5b57a0f8f831094b024bbcab5f2797bc7"},
|
||||
{file = "zarr-2.17.1.tar.gz", hash = "sha256:564b3aa072122546fe69a0fa21736f466b20fad41754334b62619f088ce46261"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -3125,4 +3254,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6"
|
||||
content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8"
|
||||
|
|
|
@ -42,13 +42,14 @@ mpmath = "^1.3.0"
|
|||
torch = "^2.2.1"
|
||||
tensordict = {git = "https://github.com/pytorch/tensordict"}
|
||||
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
|
||||
mujoco = "^3.1.2"
|
||||
mujoco = "2.3.7"
|
||||
mujoco-py = "^2.1.2.14"
|
||||
gym = "^0.26.2"
|
||||
opencv-python = "^4.9.0.80"
|
||||
diffusers = "^0.26.3"
|
||||
torchvision = "^0.17.1"
|
||||
h5py = "^3.10.0"
|
||||
dm-control = "1.0.14"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
|
|
@ -17,6 +17,7 @@ apptainer exec --nv \
|
|||
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
|
||||
|
||||
source ~/.bashrc
|
||||
conda activate fowm
|
||||
#conda activate fowm
|
||||
conda activate lerobot
|
||||
|
||||
srun $CMD
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d789deddb081a9f4b626342391de8f48949d38fb5fdead87b5c0737b46c0877a
|
||||
size 2800
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
|
||||
size 400
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
|
||||
size 400
|
|
@ -0,0 +1 @@
|
|||
{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
|
||||
size 50
|
|
@ -0,0 +1 @@
|
|||
{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1 @@
|
|||
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5c632e3cb06be729e5d673e3ecca1d6f6527b0f48cfe3dc03d7eea4f9eb3bbd7
|
||||
size 46080000
|
|
@ -0,0 +1 @@
|
|||
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e231f2e07e1cd030137ea2e938b570b112db2c694c6d21b37ceb8f8559e19088
|
||||
size 2800
|
|
@ -0,0 +1 @@
|
|||
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a1ba64c89f4fcf9135fe34c26abf582dd5f0d573506db5c96af3ffe40a52c818
|
||||
size 46080000
|
|
@ -0,0 +1 @@
|
|||
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:85405686bc065c6ab6c915907920a0391a57cf097b74de058a8c30be0548ade5
|
||||
size 2800
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1f5fe053b760e8471885b82c10f4a6ea40874098036337ae5cc300c4775546be
|
||||
size 2800
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
|
||||
size 400
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
|
||||
size 400
|
|
@ -0,0 +1 @@
|
|||
{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
|
||||
size 50
|
|
@ -0,0 +1 @@
|
|||
{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1 @@
|
|||
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:daed2bb10498ba2557983d0d7e89399882fea7585e7ceff910e23c621bfdbf88
|
||||
size 46080000
|
|
@ -0,0 +1 @@
|
|||
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bbad0302af70112ee312efe0eb0f44a2f1c8f6c5ef82ea4fb34625cdafbef057
|
||||
size 2800
|
|
@ -0,0 +1 @@
|
|||
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:aba55ebb9dd004bf68444b9ebf024ed7713436099c06a0b8e541100ecbc69290
|
||||
size 46080000
|
|
@ -0,0 +1 @@
|
|||
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dd4e7e14abf57561ca9839c910581266be90956e41bfb3bb21362ea0c321e77d
|
||||
size 2800
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:14fed0eed3d529a8ac0dd25a6d41585020772d02f9137fc9d604713b2f0f7076
|
||||
size 2800
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
|
||||
size 400
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
|
||||
size 400
|
|
@ -0,0 +1 @@
|
|||
{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
|
||||
size 50
|
|
@ -0,0 +1 @@
|
|||
{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1 @@
|
|||
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2f713ea7fc19e592ea409a5e0bdfde403e5b86f834cbabe3463b791e8437fafc
|
||||
size 46080000
|
|
@ -0,0 +1 @@
|
|||
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8c103e2c9d63c9f7cf9645bd24d9a2c4e8e08825dc75e230ebc793b8f9c213b0
|
||||
size 2800
|
|
@ -0,0 +1 @@
|
|||
{"top": {"device": "cpu", "shape": [50, 3, 480, 640], "dtype": "torch.uint8"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7dbf4aa01b184d0eaa21ea999078d7cff86e1ca484a109614176fdc49f1ee05c
|
||||
size 46080000
|
|
@ -0,0 +1 @@
|
|||
{"state": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4fa0b9c870d4615037b6fee9e9e85e54d84352e173f2c7c1035232272fe2a3dd
|
||||
size 2800
|
Binary file not shown.
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c0e199a82e2b7462e84406dbced5448a99f1dad9ce172771dfc3feb4b8597115
|
||||
size 2800
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue