Include `observation.environment_state` with keypoints in PushT dataset (#303)
Co-authored-by: Remi <re.cadene@gmail.com>
commit a4d77b99f0 (parent 7bd5ab16d1)
@@ -70,6 +70,8 @@ available_datasets_per_env = {
         "lerobot/aloha_sim_transfer_cube_human_image",
         "lerobot/aloha_sim_transfer_cube_scripted_image",
     ],
+    # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
+    # coupled with tests.
    "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
    "xarm": [
        "lerobot/xarm_lift_medium",
@@ -36,7 +36,7 @@ from lerobot.common.datasets.utils import (
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos

 DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
-CODEBASE_VERSION = "v1.4"
+CODEBASE_VERSION = "v1.5"


 class LeRobotDataset(torch.utils.data.Dataset):
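For context, this version tag is what dataset loading pins to on the hub, which is why the warning further down requires re-branching every dataset after a bump. A minimal sketch of the idea, assuming the standard `datasets` API (the exact loading path inside `lerobot_dataset.py` may differ):

```python
from datasets import load_dataset

CODEBASE_VERSION = "v1.5"

# Pin the hub revision to the codebase version so that older checkouts
# keep resolving a branch with a compatible data format.
hf_dataset = load_dataset("lerobot/pusht", revision=CODEBASE_VERSION, split="train")
```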
@@ -54,7 +54,14 @@ def check_format(raw_dir):
     assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+    keypoints_instead_of_image: bool = False,
+):
     try:
         import pymunk
         from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -105,10 +112,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         assert (episode_ids[from_idx:to_idx] == ep_idx).all()

         # get image
-        image = imgs[from_idx:to_idx]
-        assert image.min() >= 0.0
-        assert image.max() <= 255.0
-        image = image.type(torch.uint8)
+        if not keypoints_instead_of_image:
+            image = imgs[from_idx:to_idx]
+            assert image.min() >= 0.0
+            assert image.max() <= 255.0
+            image = image.type(torch.uint8)

         # get state
         state = states[from_idx:to_idx]
@@ -116,9 +124,11 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         block_pos = state[:, 2:4]
         block_angle = state[:, 4]

-        # get reward, success, done
+        # get reward, success, done, and (maybe) keypoints
         reward = torch.zeros(num_frames)
         success = torch.zeros(num_frames, dtype=torch.bool)
+        if keypoints_instead_of_image:
+            keypoints = torch.zeros(num_frames, 16)  # 8 keypoints each with 2 coords
         done = torch.zeros(num_frames, dtype=torch.bool)
         for i in range(num_frames):
             space = pymunk.Space()
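As an aside, the `16` above is just the flattened form of 8 (x, y) keypoints per frame. A tiny illustration of that layout, with a hypothetical stand-in for `PushTEnv.get_keypoints` (the real function lives in gym-pusht):

```python
import numpy as np
import torch

# Hypothetical stand-in: 8 corner points of the T block as (x, y) pairs.
keypoints_2d = np.zeros((8, 2), dtype=np.float32)

flat = torch.from_numpy(keypoints_2d.flatten())  # shape (16,)
assert flat.shape == (16,)
```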
@@ -134,7 +144,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             ]
             space.add(*walls)

-            block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
+            block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
             goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
             block_geom = pymunk_to_shapely(block_body, block_body.shapes)
             intersection_area = goal_geom.intersection(block_geom).area
@@ -142,33 +152,40 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
             coverage = intersection_area / goal_area
             reward[i] = np.clip(coverage / success_threshold, 0, 1)
             success[i] = coverage > success_threshold
+            if keypoints_instead_of_image:
+                keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

         # last step of demonstration is considered done
         done[-1] = True

         ep_dict = {}

-        imgs_array = [x.numpy() for x in image]
-        img_key = "observation.image"
-        if video:
-            # save png images in temporary directory
-            tmp_imgs_dir = videos_dir / "tmp_images"
-            save_images_concurrently(imgs_array, tmp_imgs_dir)
+        if not keypoints_instead_of_image:
+            imgs_array = [x.numpy() for x in image]
+            img_key = "observation.image"
+            if video:
+                # save png images in temporary directory
+                tmp_imgs_dir = videos_dir / "tmp_images"
+                save_images_concurrently(imgs_array, tmp_imgs_dir)

-            # encode images to a mp4 video
-            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
-            video_path = videos_dir / fname
-            encode_video_frames(tmp_imgs_dir, video_path, fps)
+                # encode images to a mp4 video
+                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+                video_path = videos_dir / fname
+                encode_video_frames(tmp_imgs_dir, video_path, fps)

-            # clean temporary images directory
-            shutil.rmtree(tmp_imgs_dir)
+                # clean temporary images directory
+                shutil.rmtree(tmp_imgs_dir)

-            # store the reference to the video frame
-            ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
-        else:
-            ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
+                # store the reference to the video frame
+                ep_dict[img_key] = [
+                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                ]
+            else:
+                ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

         ep_dict["observation.state"] = agent_pos
+        if keypoints_instead_of_image:
+            ep_dict["observation.environment_state"] = keypoints
         ep_dict["action"] = actions[from_idx:to_idx]
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
         ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
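In keypoints mode, each episode dict therefore carries `observation.environment_state` instead of `observation.image`. A rough sketch of the resulting per-episode layout (shapes assume PushT's 2-D agent state and action; `num_frames` is illustrative):

```python
import torch

num_frames = 100  # illustrative
ep_dict = {
    "observation.state": torch.zeros(num_frames, 2),               # agent (x, y)
    "observation.environment_state": torch.zeros(num_frames, 16),  # 8 keypoints, flattened
    "action": torch.zeros(num_frames, 2),
    "episode_index": torch.zeros(num_frames, dtype=torch.int64),
    "frame_index": torch.arange(num_frames),
}
```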
@@ -180,7 +197,6 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
         ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
         ep_dicts.append(ep_dict)
-
     data_dict = concatenate_episodes(ep_dicts)

     total_frames = data_dict["frame_index"].shape[0]
@@ -188,17 +204,23 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
     return data_dict


-def to_hf_dataset(data_dict, video):
+def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
     features = {}

-    if video:
-        features["observation.image"] = VideoFrame()
-    else:
-        features["observation.image"] = Image()
+    if not keypoints_instead_of_image:
+        if video:
+            features["observation.image"] = VideoFrame()
+        else:
+            features["observation.image"] = Image()

     features["observation.state"] = Sequence(
         length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
     )
+    if keypoints_instead_of_image:
+        features["observation.environment_state"] = Sequence(
+            length=data_dict["observation.environment_state"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        )
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
     )
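Concretely, the feature spec produced in keypoints mode should look roughly like this sketch, written directly against the `datasets` API (lengths assume PushT's 2-D state and action and the 16-D keypoint vector):

```python
from datasets import Features, Sequence, Value

features = Features(
    {
        "observation.state": Sequence(length=2, feature=Value(dtype="float32")),
        "observation.environment_state": Sequence(length=16, feature=Value(dtype="float32")),
        "action": Sequence(length=2, feature=Value(dtype="float32")),
    }
)
```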
@@ -222,17 +244,21 @@ def from_raw_to_lerobot_format(
     video: bool = True,
     episodes: list[int] | None = None,
 ):
+    # Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
+    # with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
+    keypoints_instead_of_image = False
+
     # sanity check
     check_format(raw_dir)

     if fps is None:
         fps = 10

-    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
-    hf_dataset = to_hf_dataset(data_dict, video)
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image)
+    hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
         "fps": fps,
-        "video": video,
+        "video": video if not keypoints_instead_of_image else 0,
     }
     return hf_dataset, episode_data_index, info
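Because the flag is hard-coded rather than exposed as an argument, producing a keypoints dataset means flipping `keypoints_instead_of_image = True` locally and re-running the conversion. A hypothetical call with placeholder paths (after that edit, the returned `info["video"]` is 0 regardless of the `video` argument):

```python
from pathlib import Path

hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
    raw_dir=Path("data/pusht_raw"),        # placeholder path to the raw zarr data
    videos_dir=Path("data/pusht_videos"),  # unused when no image observations are written
    fps=None,                              # defaults to 10 for PushT
    video=False,
)
assert info["video"] == 0
```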
@@ -40,6 +40,60 @@ python lerobot/scripts/push_dataset_to_hub.py \
     --raw-format umi_zarr \
     --repo-id lerobot/umi_cup_in_the_wild
 ```
+
+**WARNING: Updating an existing dataset**
+
+If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` in `lerobot_dataset.py`
+before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change,
+intentionally or not (i.e. something not backward compatible, such as modifying the reward function or
+deleting frames at the end of an episode). That way, people running a previous version of the
+codebase won't be affected by your change and backward compatibility is maintained.
+
+For instance, PushT has many versions to maintain backward compatibility between LeRobot codebase versions:
+- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
+- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
+- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
+- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
+- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
+- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- latest version
+- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the latest version
+
+However, you will then need to update the version of ALL the other datasets so that they have the new
+`CODEBASE_VERSION` as a branch in their Hugging Face dataset repository. Don't worry, there is an easy way
+that doesn't require running `push_dataset_to_hub.py`. You can just branch out from the `main` branch on
+each HF dataset repo by running this script, which corresponds to a `git checkout -b` (so no copy or upload
+is needed):
+
+```python
+import os
+
+from huggingface_hub import create_branch, hf_hub_download
+from huggingface_hub.utils._errors import RevisionNotFoundError
+
+from lerobot import available_datasets
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
+
+os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # makes it easier to see the print-out below
+
+NEW_CODEBASE_VERSION = "v1.5"  # REPLACE THIS WITH YOUR DESIRED VERSION
+
+for repo_id in available_datasets:
+    # First check if the newer version already exists.
+    try:
+        hf_hub_download(
+            repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
+        )
+        print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
+        print("Exiting early")
+        break
+    except RevisionNotFoundError:
+        # The branch doesn't exist yet, so create it from the current version.
+        create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
+        print(f"{repo_id} successfully updated")
+```
+
+On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
+above, nor about compatibility with previous codebase versions.
 """

 import argparse
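After running the branch script from the docstring above, one quick way to spot-check the result is to list a repo's refs (a sketch; `list_repo_refs` is a standard `huggingface_hub` call but is not part of this commit):

```python
from huggingface_hub import HfApi

refs = HfApi().list_repo_refs("lerobot/pusht", repo_type="dataset")
print(sorted(ref.name for ref in refs.branches))  # the new version tag should appear here
```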
@@ -317,7 +371,10 @@ def main():
     parser.add_argument(
         "--tests-data-dir",
         type=Path,
-        help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
+        help=(
+            "When provided, save tests artifacts into the given directory "
+            "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
+        ),
     )

     args = parser.parse_args()
(One file diff is suppressed because it is too large.)
@@ -47,7 +47,7 @@ huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
 gymnasium = ">=0.29.1"
 cmake = ">=3.29.0.1"
 gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
-gym-pusht = { version = ">=0.1.3", optional = true}
+gym-pusht = { version = ">=0.1.5", optional = true}
 gym-xarm = { version = ">=0.1.1", optional = true}
 gym-aloha = { version = ">=0.1.1", optional = true}
 pre-commit = {version = ">=3.7.0", optional = true}
The remaining hunks update Git LFS pointers for test data artifacts, one hunk per file:

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d5883aa2c8ba2bcd8d047a77064112aa5d4c1c9b8595bb28935ec93ed53627e5
+oid sha256:52723265cba2ec839a5fcf75733813ecf91019ec0f7a49865fe233616e674583
 size 3056

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0eab443dd492d0e271094290ae3cec2c9b2f4a19d35434eb5952cb37b0d40890
-size 18272
+oid sha256:8552d4ac6b618a5b2741e174d51f1d4fc0e5f4e6cc7026bebdb6ed145373b042
+size 18320

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8c1a72239bb56a6c5714f18d849557c89feb858840e8f86689d017bb49551379
+oid sha256:a522c7815565f1f81a8bb5a853263405ab8c3b087ecbc7a3b004848891d77342
 size 247

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1cd3db853d0f92e1696fe297c550200219d85befdeb5b5eacae4b10a74d9896
+size 136

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbf25de102227dd2d8c3b6c61e1fc25a026d44f151161b88bc9a9eb101e942e4
+size 33

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50b3c026da835560f9b87e7dfd28673e766bfb58d56c85002687d0a599b6fa43
+size 3304

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:958798d23a1690449744961f8c3ed934efe950c664e5fd729468959362840218
+size 20336

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:686d9d9bad8815d67597b997058d9853a04e5bdbe4eed038f4da9806f867af3d
+size 1098

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f22ee3500aca1bea0afdda429e841c57a3278dfea92c79bbbf5dac5f984ed648
+size 247

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d9cc073bcb335024500fe7c823f142a3b4f038ff458d8c47fb6a6918f8f6d5fd
+oid sha256:b99bbb7332557d47b108fd0262d911c99f5bfce30fa3e76dc802b927284135e7
 size 111338

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:58c50ef6413b6b3acb7ad280281cdd4eba553f7d3d0b4dad20c262025d610f2b
+oid sha256:0f63430455e1ca7a5fe28c81a15fc0eb82758035e6b3d623e7e7952e71cb262a
 size 111338

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bd1d26e983e2910ec170cd6ac1f4de4d7cb447ee24b516a74f42765d4894e048
+oid sha256:0b88c39db5b13da646fd5876bd765213569387591d30ec665d048ae1070db0b9
 size 111338

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e1247a9d4683520ed338f3fd410cc200999e4b82da573cd499095ba02037586f
+oid sha256:68eb245890f9537851ea7fb227472dcd4f1fa3820a7c3294a4989e2b9896d078
 size 111338

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b24f3c3d41428b768082eb3b02b5e22dc9540aa4dbe756d43be214d51e97adba
+oid sha256:00c74e17bbf7d428b0b0869f388d348820a938c417b3c888a1384980bb53d4d0
 size 111338

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5301dc61b585fbfbdb6ce681ffcd52fc26b64a3767567c228a9e4404f7bcb926
+oid sha256:a5a7f66704640ba18f756fc44c00721c77a406f412a3a9fcc1a2b1868c978444
 size 111338

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ae7f6a7f4ee8340ec73b0e7f1e167046af2af0a22381e0cd3ff42f311e098e0
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eeb1b185b505450f8a2b6042537d65d2d8f5ee1396cf878a50d3d2aa3a22822
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f2bb24887f9d4c49ad562429f419b7b66f4310a59877104a98d3c5c6ddca996
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a52fe583c816fdfb962111dd1ee1c113a5f4b9699246fab8648f89e056979f8e
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70dbf161581b860e255573eb1ef90f4defd134d8dcf0afea16099c859c4a8f85
+size 794

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:198abd0ec4231c13cadf707d553cba3860acbc74a073406ed184eab5495acdfa
+size 794