Include `observation.environment_state` with keypoints in PushT dataset (#303)

Co-authored-by: Remi <re.cadene@gmail.com>
Author: Alexander Soare, 2024-07-09 08:27:40 +01:00 (committed by GitHub)
parent 7bd5ab16d1
commit a4d77b99f0
27 changed files with 739 additions and 609 deletions

lerobot/__init__.py

@@ -70,6 +70,8 @@ available_datasets_per_env = {
        "lerobot/aloha_sim_transfer_cube_human_image",
        "lerobot/aloha_sim_transfer_cube_scripted_image",
    ],
+    # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
+    # coupled with tests.
    "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
    "xarm": [
        "lerobot/xarm_lift_medium",

lerobot/common/datasets/lerobot_dataset.py

@@ -36,7 +36,7 @@ from lerobot.common.datasets.utils import (
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos

DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-CODEBASE_VERSION = "v1.4"
+CODEBASE_VERSION = "v1.5"

class LeRobotDataset(torch.utils.data.Dataset):
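For context, `CODEBASE_VERSION` ties the codebase to a matching Hub branch: dataset files are fetched at the revision named after it, which is why the bump to `v1.5` accompanies this breaking change. A minimal sketch of that idea, assuming `snapshot_download` (illustrative only, not the loader's actual code path):

```python
from huggingface_hub import snapshot_download

CODEBASE_VERSION = "v1.5"

def download_dataset_snapshot(repo_id: str) -> str:
    # Resolve the dataset at the branch matching this codebase version, so an
    # older codebase keeps downloading data it is compatible with.
    return snapshot_download(repo_id, repo_type="dataset", revision=CODEBASE_VERSION)
```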

lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py

@@ -54,7 +54,14 @@ def check_format(raw_dir):
    assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)

-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+    keypoints_instead_of_image: bool = False,
+):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely

@@ -105,6 +112,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        assert (episode_ids[from_idx:to_idx] == ep_idx).all()

        # get image
-        image = imgs[from_idx:to_idx]
-        assert image.min() >= 0.0
-        assert image.max() <= 255.0
+        if not keypoints_instead_of_image:
+            image = imgs[from_idx:to_idx]
+            assert image.min() >= 0.0
+            assert image.max() <= 255.0

@@ -116,9 +124,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        block_pos = state[:, 2:4]
        block_angle = state[:, 4]

-        # get reward, success, done
+        # get reward, success, done, and (maybe) keypoints
        reward = torch.zeros(num_frames)
        success = torch.zeros(num_frames, dtype=torch.bool)
+        if keypoints_instead_of_image:
+            keypoints = torch.zeros(num_frames, 16)  # 8 keypoints each with 2 coords
        done = torch.zeros(num_frames, dtype=torch.bool)
        for i in range(num_frames):
            space = pymunk.Space()

@@ -134,7 +144,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
            ]
            space.add(*walls)

-            block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
+            block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
            goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
            block_geom = pymunk_to_shapely(block_body, block_body.shapes)
            intersection_area = goal_geom.intersection(block_geom).area

@@ -142,12 +152,15 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
            coverage = intersection_area / goal_area
            reward[i] = np.clip(coverage / success_threshold, 0, 1)
            success[i] = coverage > success_threshold
+            if keypoints_instead_of_image:
+                keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

        # last step of demonstration is considered done
        done[-1] = True

        ep_dict = {}
-        imgs_array = [x.numpy() for x in image]
-        img_key = "observation.image"
-        if video:
+        if not keypoints_instead_of_image:
+            imgs_array = [x.numpy() for x in image]
+            img_key = "observation.image"
+            if video:

@@ -164,11 +177,15 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
                shutil.rmtree(tmp_imgs_dir)

                # store the reference to the video frame
-                ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
+                ep_dict[img_key] = [
+                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                ]
            else:
                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = agent_pos
+        if keypoints_instead_of_image:
+            ep_dict["observation.environment_state"] = keypoints
        ep_dict["action"] = actions[from_idx:to_idx]
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)

@@ -180,7 +197,6 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
        ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
        ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
        ep_dicts.append(ep_dict)
-
    data_dict = concatenate_episodes(ep_dicts)
    total_frames = data_dict["frame_index"].shape[0]
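An aside on the keypoint layout: each frame stores the T block's 8 (x, y) keypoints flattened into a 16-dim vector, which is exactly what `torch.zeros(num_frames, 16)` above pre-allocates. A tiny self-contained sketch with stand-in values (the real coordinates come from `PushTEnv.get_keypoints`):

```python
import numpy as np
import torch

num_frames = 2
keypoints = torch.zeros(num_frames, 16)  # 8 keypoints x 2 coords per frame

# Stand-in for PushTEnv.get_keypoints(block_shapes): an (8, 2) array of coords.
frame_keypoints = np.arange(16, dtype=np.float32).reshape(8, 2)
keypoints[0] = torch.from_numpy(frame_keypoints.flatten())
assert keypoints[0].shape == (16,)
```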
@@ -188,9 +204,10 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
    return data_dict

-def to_hf_dataset(data_dict, video):
+def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
    features = {}

-    if video:
-        features["observation.image"] = VideoFrame()
-    else:
+    if not keypoints_instead_of_image:
+        if video:
+            features["observation.image"] = VideoFrame()
+        else:

@@ -199,6 +216,11 @@ def to_hf_dataset(data_dict, video):
    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
+    if keypoints_instead_of_image:
+        features["observation.environment_state"] = Sequence(
+            length=data_dict["observation.environment_state"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )

@@ -222,17 +244,21 @@ def from_raw_to_lerobot_format(
    video: bool = True,
    episodes: list[int] | None = None,
):
+    # Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
+    # with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
+    keypoints_instead_of_image = False
+
    # sanity check
    check_format(raw_dir)

    if fps is None:
        fps = 10

-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
-    hf_dataset = to_hf_dataset(data_dict, video)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image)
+    hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
-        "video": video,
+        "video": video if not keypoints_instead_of_image else 0,
    }

    return hf_dataset, episode_data_index, info
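If you flip `keypoints_instead_of_image` to `True` locally, a conversion run would look something like the following (a hedged sketch: `data/pusht_raw` is a hypothetical path, and `pusht_zarr` and `--video` follow the CLI conventions shown in the script's docstring below):

```bash
python lerobot/scripts/push_dataset_to_hub.py \
    --raw-dir data/pusht_raw \
    --raw-format pusht_zarr \
    --repo-id lerobot/pusht_keypoints \
    --video 0
```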

lerobot/scripts/push_dataset_to_hub.py

@@ -40,6 +40,60 @@ python lerobot/scripts/push_dataset_to_hub.py \
--raw-format umi_zarr \
--repo-id lerobot/umi_cup_in_the_wild
```
**WARNING: Updating an existing dataset**

If you want to update an existing dataset, you need to bump the `CODEBASE_VERSION` in `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially important if you introduce a breaking change,
intentionally or not (e.g. something that is not backward compatible, such as modifying the reward
function or deleting some frames at the end of an episode). That way, people running a previous version
of the codebase won't be affected by your change, and backward compatibility is maintained.

For instance, PushT has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- latest version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the latest version

However, you will also need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their Hugging Face dataset repository. Don't worry, there is an easy way
that doesn't require running `push_dataset_to_hub.py`. You can just "branch out" from the `main` branch of
each HF dataset repo by running this script, which corresponds to a `git checkout -b` (so no copy or upload
is needed):
```python
import os

from huggingface_hub import create_branch, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError

from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # makes it easier to see the print-out below

NEW_CODEBASE_VERSION = "v1.5"  # REPLACE THIS WITH YOUR DESIRED VERSION

for repo_id in available_datasets:
    # First check if the newer version already exists.
    try:
        hf_hub_download(
            repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
        )
        print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
        print("Exiting early")
        break
    except (RepositoryNotFoundError, RevisionNotFoundError):
        # The new branch doesn't exist yet, so create it from the current codebase version's branch.
        create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
        print(f"{repo_id} successfully updated")
```
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
above, nor about staying compatible with previous codebase versions.
""" """
import argparse import argparse
@@ -317,7 +371,10 @@ def main():
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
-        help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
+        help=(
+            "When provided, save tests artifacts into the given directory "
+            "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
+        ),
    )
    args = parser.parse_args()
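As a quick sanity check after running the branch-creation snippet from the docstring above, `huggingface_hub` can list a repo's branches (a hedged convenience check, not part of this commit):

```python
from huggingface_hub import HfApi

# List the branches of one dataset repo; the new version branch should appear.
refs = HfApi().list_repo_refs("lerobot/pusht", repo_type="dataset")
print(sorted(ref.name for ref in refs.branches))
```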

poetry.lock (generated, 1137 lines changed): diff suppressed because it is too large.

pyproject.toml

@@ -47,7 +47,7 @@ huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1"
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
-gym-pusht = { version = ">=0.1.3", optional = true}
+gym-pusht = { version = ">=0.1.5", optional = true}
gym-xarm = { version = ">=0.1.1", optional = true}
gym-aloha = { version = ">=0.1.1", optional = true}
pre-commit = {version = ">=3.7.0", optional = true}
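The `gym-pusht` floor is raised presumably because this commit relies on `PushTEnv.add_tee` returning the block's shapes and on `PushTEnv.get_keypoints`. A hedged runtime check (the 0.1.5 cutoff is an assumption inferred from this version bump):

```python
from importlib.metadata import version

from packaging.version import Version

# Assumption: keypoint support (add_tee returning shapes, get_keypoints) landed in 0.1.5.
assert Version(version("gym-pusht")) >= Version("0.1.5"), "please upgrade gym-pusht"
```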

Git LFS pointer files (binary test artifacts; file names are not shown in this view):

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:d5883aa2c8ba2bcd8d047a77064112aa5d4c1c9b8595bb28935ec93ed53627e5
+oid sha256:52723265cba2ec839a5fcf75733813ecf91019ec0f7a49865fe233616e674583
size 3056

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:0eab443dd492d0e271094290ae3cec2c9b2f4a19d35434eb5952cb37b0d40890
-size 18272
+oid sha256:8552d4ac6b618a5b2741e174d51f1d4fc0e5f4e6cc7026bebdb6ed145373b042
+size 18320

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:8c1a72239bb56a6c5714f18d849557c89feb858840e8f86689d017bb49551379
+oid sha256:a522c7815565f1f81a8bb5a853263405ab8c3b087ecbc7a3b004848891d77342
size 247

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1cd3db853d0f92e1696fe297c550200219d85befdeb5b5eacae4b10a74d9896
+size 136

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbf25de102227dd2d8c3b6c61e1fc25a026d44f151161b88bc9a9eb101e942e4
+size 33

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50b3c026da835560f9b87e7dfd28673e766bfb58d56c85002687d0a599b6fa43
+size 3304

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:958798d23a1690449744961f8c3ed934efe950c664e5fd729468959362840218
+size 20336

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:686d9d9bad8815d67597b997058d9853a04e5bdbe4eed038f4da9806f867af3d
+size 1098

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f22ee3500aca1bea0afdda429e841c57a3278dfea92c79bbbf5dac5f984ed648
+size 247

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:d9cc073bcb335024500fe7c823f142a3b4f038ff458d8c47fb6a6918f8f6d5fd
+oid sha256:b99bbb7332557d47b108fd0262d911c99f5bfce30fa3e76dc802b927284135e7
size 111338

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:58c50ef6413b6b3acb7ad280281cdd4eba553f7d3d0b4dad20c262025d610f2b
+oid sha256:0f63430455e1ca7a5fe28c81a15fc0eb82758035e6b3d623e7e7952e71cb262a
size 111338

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:bd1d26e983e2910ec170cd6ac1f4de4d7cb447ee24b516a74f42765d4894e048
+oid sha256:0b88c39db5b13da646fd5876bd765213569387591d30ec665d048ae1070db0b9
size 111338

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:e1247a9d4683520ed338f3fd410cc200999e4b82da573cd499095ba02037586f
+oid sha256:68eb245890f9537851ea7fb227472dcd4f1fa3820a7c3294a4989e2b9896d078
size 111338

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:b24f3c3d41428b768082eb3b02b5e22dc9540aa4dbe756d43be214d51e97adba
+oid sha256:00c74e17bbf7d428b0b0869f388d348820a938c417b3c888a1384980bb53d4d0
size 111338

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:5301dc61b585fbfbdb6ce681ffcd52fc26b64a3767567c228a9e4404f7bcb926
+oid sha256:a5a7f66704640ba18f756fc44c00721c77a406f412a3a9fcc1a2b1868c978444
size 111338

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ae7f6a7f4ee8340ec73b0e7f1e167046af2af0a22381e0cd3ff42f311e098e0
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eeb1b185b505450f8a2b6042537d65d2d8f5ee1396cf878a50d3d2aa3a22822
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f2bb24887f9d4c49ad562429f419b7b66f4310a59877104a98d3c5c6ddca996
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a52fe583c816fdfb962111dd1ee1c113a5f4b9699246fab8648f89e056979f8e
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70dbf161581b860e255573eb1ef90f4defd134d8dcf0afea16099c859c4a8f85
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:198abd0ec4231c13cadf707d553cba3860acbc74a073406ed184eab5495acdfa
+size 794