Improve dataset v2 (#498)
This commit is contained in: parent acae4b49d2, commit 1f13bda25b

@@ -0,0 +1,246 @@
import shutil
from pathlib import Path

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw


def create_empty_dataset(repo_id, mode):
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (2,),
            "names": [
                ["x", "y"],
            ],
        },
        "action": {
            "dtype": "float32",
            "shape": (2,),
            "names": [
                ["x", "y"],
            ],
        },
        "next.reward": {
            "dtype": "float32",
            "shape": (1,),
            "names": None,
        },
        "next.success": {
            "dtype": "bool",
            "shape": (1,),
            "names": None,
        },
    }

    if mode == "keypoints":
        features["observation.environment_state"] = {
            "dtype": "float32",
            "shape": (16,),
            "names": [
                "keypoints",
            ],
        }
    else:
        features["observation.image"] = {
            "dtype": mode,
            "shape": (3, 96, 96),
            "names": [
                "channel",
                "height",
                "width",
            ],
        }

    dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=10,
        robot_type="2d pointer",
        features=features,
        image_writer_threads=4,
    )
    return dataset


def load_raw_dataset(zarr_path, load_images=True):
    try:
        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e

    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

    env_state = zarr_data["state"][:]
    agent_pos = env_state[:, :2]
    block_pos = env_state[:, 2:4]
    block_angle = env_state[:, 4]

    action = zarr_data["action"][:]

    image = None
    if load_images:
        # b h w c
        image = zarr_data["img"]

    episode_data_index = {
        "from": np.array([0] + zarr_data.meta["episode_ends"][:-1].tolist()),
        "to": zarr_data.meta["episode_ends"],
    }

    return image, agent_pos, block_pos, block_angle, action, episode_data_index
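For intuition, `episode_data_index` stores paired start/end frame offsets per episode, derived from the zarr `episode_ends` metadata. A minimal worked example (the `episode_ends` values are made up):

```python
import numpy as np

# Made-up episode_ends: episode 0 covers frames [0, 3), episode 1 covers [3, 5)
episode_ends = np.array([3, 5])

episode_data_index = {
    "from": np.array([0] + episode_ends[:-1].tolist()),  # array([0, 3])
    "to": episode_ends,                                  # array([3, 5])
}
```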


def calculate_coverage(block_pos, block_angle):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e

    num_frames = len(block_pos)

    coverage = np.zeros((num_frames,))
    # 8 keypoints with 2 coords each
    keypoints = np.zeros((num_frames, 16))

    # Set x, y, theta (in radians)
    goal_pos_angle = np.array([256, 256, np.pi / 4])
    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

    for i in range(num_frames):
        space = pymunk.Space()
        space.gravity = 0, 0
        space.damping = 0

        # Add walls.
        walls = [
            PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
            PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
            PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
            PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
        ]
        space.add(*walls)

        block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
        goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
        block_geom = pymunk_to_shapely(block_body, block_body.shapes)
        intersection_area = goal_geom.intersection(block_geom).area
        goal_area = goal_geom.area
        coverage[i] = intersection_area / goal_area
        keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())

    return coverage, keypoints


def calculate_success(coverage, success_threshold):
    return coverage > success_threshold


def calculate_reward(coverage, success_threshold):
    return np.clip(coverage / success_threshold, 0, 1)


def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success):
    if episodes is None:
        episodes = range(len(episode_data_index["from"]))

    for ep_idx in episodes:
        from_idx = episode_data_index["from"][ep_idx]
        to_idx = episode_data_index["to"][ep_idx]
        num_frames = to_idx - from_idx

        for frame_idx in range(num_frames):
            i = from_idx + frame_idx

            frame = {
                "action": torch.from_numpy(action[i]),
                # Shift reward and success by +1 until the last item of the episode
                "next.reward": reward[i + (frame_idx < num_frames - 1)],
                "next.success": success[i + (frame_idx < num_frames - 1)],
            }

            frame["observation.state"] = torch.from_numpy(state[i])

            if env_state is not None:
                frame["observation.environment_state"] = torch.from_numpy(env_state[i])

            if image is not None:
                frame["observation.image"] = torch.from_numpy(image[i])

            dataset.add_frame(frame)

        dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")

    return dataset


def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
    if mode not in ["video", "image", "keypoints"]:
        raise ValueError(mode)

    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    raw_dir = Path(raw_dir)
    if not raw_dir.exists():
        download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")

    image, agent_pos, block_pos, block_angle, action, episode_data_index = load_raw_dataset(
        zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr"
    )

    # Calculate success and reward based on the overlapping area
    # of the T-object and the T-area.
    coverage, keypoints = calculate_coverage(block_pos, block_angle)
    success = calculate_success(coverage, success_threshold=0.95)
    reward = calculate_reward(coverage, success_threshold=0.95)

    dataset = create_empty_dataset(repo_id, mode)
    dataset = populate_dataset(
        dataset,
        episode_data_index,
        episodes,
        image=None if mode == "keypoints" else image,
        state=agent_pos,
        env_state=keypoints if mode == "keypoints" else None,
        action=action,
        reward=reward,
        success=success,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    # To try this script, modify the repo id with your own HuggingFace user (e.g. cadene/pusht)
    repo_id = "lerobot/pusht"

    episodes = None
    # Uncomment if you want to try with a subset (episodes 0 and 1)
    # episodes = [0, 1]

    modes = ["video", "image", "keypoints"]
    # Uncomment if you want to try with a specific mode
    # modes = ["video"]
    # modes = ["image"]
    # modes = ["keypoints"]

    for mode in modes:
        # Suffix the repo id per mode without mutating it across iterations
        mode_repo_id = f"{repo_id}_{mode}" if mode in ["image", "keypoints"] else repo_id

        # Download and load raw dataset, create LeRobotDataset, populate it, push to hub
        port_pusht("data/lerobot-raw/pusht_raw", repo_id=mode_repo_id, mode=mode, episodes=episodes)

        # Uncomment if you want to load the local dataset and explore it
        # dataset = LeRobotDataset(repo_id=mode_repo_id, local_files_only=True)
        # breakpoint()
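Once ported and consolidated, the dataset can be read back like any other `LeRobotDataset`. A minimal sketch of exploring it locally (the frame keys follow the `features` declared in `create_empty_dataset` above, so shapes reflect the `(2,)` state/action declared there):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load the locally ported dataset without hitting the hub
dataset = LeRobotDataset(repo_id="lerobot/pusht", local_files_only=True)

frame = dataset[0]  # a dict of tensors keyed by the declared features
print(frame["observation.state"].shape)  # torch.Size([2]) -> (x, y)
print(frame["action"].shape)             # torch.Size([2])
print(frame["next.reward"], frame["next.success"])
```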
@@ -280,6 +280,8 @@ class LeRobotDatasetMetadata:
         obj.repo_id = repo_id
         obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
 
+        obj.root.mkdir(parents=True, exist_ok=False)
+
         if robot is not None:
             features = get_features_from_robot(robot, use_videos)
             robot_type = robot.robot_type
@@ -293,6 +295,7 @@ class LeRobotDatasetMetadata:
                 "Dataset features must either come from a Robot or explicitly passed upon creation."
             )
         else:
+            # TODO(aliberts, rcadene): implement sanity check for features
            features = {**features, **DEFAULT_FEATURES}
 
         obj.tasks, obj.stats, obj.episodes = {}, {}, []
@@ -424,11 +427,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.delta_indices = None
         self.local_files_only = local_files_only
+        self.consolidated = True
 
         # Unused attributes
         self.image_writer = None
-        self.episode_buffer = {}
+        self.episode_buffer = None
 
         self.root.mkdir(exist_ok=True, parents=True)
@@ -451,12 +453,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
+        # Available stats implies all videos have been encoded and dataset is iterable
+        self.consolidated = self.meta.stats is not None
+
     def push_to_hub(
         self,
         tags: list | None = None,
         text: str | None = None,
-        license: str | None = "mit",
+        license: str | None = "apache-2.0",
         push_videos: bool = True,
+        private: bool = False,
     ) -> None:
         if not self.consolidated:
             raise RuntimeError(
@@ -468,7 +474,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if not push_videos:
             ignore_patterns.append("videos/")
 
-        create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
+        create_repo(
+            repo_id=self.repo_id,
+            private=private,
+            repo_type="dataset",
+            exist_ok=True,
+        )
 
         upload_folder(
             repo_id=self.repo_id,
             folder_path=self.root,
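Taken together with the signature change above, a private upload might look like the following sketch (`my_dataset` is a placeholder for any consolidated `LeRobotDataset`):

```python
# `my_dataset` is assumed to be a consolidated LeRobotDataset
my_dataset.push_to_hub(
    tags=["pusht"],        # placeholder tags
    license="apache-2.0",  # the new default license
    push_videos=True,
    private=True,          # newly supported: the hub repo is created private
)
```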
@@ -658,7 +670,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
         return {
             "size": 0,
-            **{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
+            **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
         }
 
     def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
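Both comprehensions produce the same buffer; the rewrite just reads more naturally. For a hypothetical feature set, a freshly created buffer would look like:

```python
# Hypothetical features: action, episode_index, frame_index, timestamp
# With current_ep_idx == 3, the returned buffer is:
buffer = {
    "size": 0,
    "action": [],
    "episode_index": 3,  # the only key bound to a scalar instead of a list
    "frame_index": [],
    "timestamp": [],
}
```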
@@ -681,8 +693,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
         then needs to be called.
         """
+        # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
+        # check the dtype and shape matches, etc.
+
+        if self.episode_buffer is None:
+            self.episode_buffer = self._create_episode_buffer()
+
         frame_index = self.episode_buffer["size"]
-        timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
         self.episode_buffer["frame_index"].append(frame_index)
         self.episode_buffer["timestamp"].append(timestamp)
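Since the buffer is now created lazily, callers no longer need any setup before the first frame. A sketch of the intended recording flow (`num_steps` and `build_frame` are hypothetical; `dataset` is assumed to come from `LeRobotDataset.create(...)`):

```python
for step in range(num_steps):
    frame = build_frame(step)  # hypothetical helper returning a dict of tensors
    dataset.add_frame(frame)   # buffered in memory / temp images only

dataset.save_episode(task="Describe the task.")  # writes the episode to disk
dataset.consolidate()                            # encodes videos, computes stats
```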
@@ -723,6 +741,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # TODO(aliberts): Add option to use existing episode_index
             raise NotImplementedError()
 
+        if episode_length == 0:
+            raise ValueError(
+                "You must add one or several frames with `add_frame` before calling `add_episode`."
+            )
+
         task_index = self.meta.get_task_index(task)
 
         if not set(episode_buffer.keys()) == set(self.features):
@@ -781,7 +804,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
 
-    def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
+    def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
         if isinstance(self.image_writer, AsyncImageWriter):
             logging.warning(
                 "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import textwrap
 import warnings
 from itertools import accumulate
 from pathlib import Path
@@ -139,6 +140,8 @@ def load_info(local_dir: Path) -> dict:
 
 
 def load_stats(local_dir: Path) -> dict:
+    if not (local_dir / STATS_PATH).exists():
+        return None
     stats = load_json(local_dir / STATS_PATH)
     stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
     return unflatten_dict(stats)
@@ -186,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
     return int(split[0]), int(split[1])
 
 
+class BackwardCompatibilityError(Exception):
+    def __init__(self, repo_id, version):
+        message = textwrap.dedent(f"""
+            BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
+
+            We introduced a new format since v2.0 which is not backward compatible with v1.x.
+            Please, use our conversion script. Modify the following command with your own task description:
+            ```
+            python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
+                --repo-id {repo_id} \\
+                --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
+            ```
+
+            A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
+            "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
+            "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
+            "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
+
+            If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
+            or open an issue on GitHub.
+        """)
+        super().__init__(message)
+
+
 def check_version_compatibility(
     repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
 ) -> None:
     current_major, _ = _get_major_minor(current_version)
     major_to_check, _ = _get_major_minor(version_to_check)
     if major_to_check < current_major and enforce_breaking_major:
-        raise ValueError(
-            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
-            format with v2.0 that is not backward compatible. Please use our conversion script
-            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
-        )
+        raise BackwardCompatibilityError(repo_id, version_to_check)
     elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
         warnings.warn(
             f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
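Callers can now catch a dedicated exception type instead of pattern-matching a generic `ValueError`. A sketch, assuming these names are importable from this module:

```python
repo_id = "lerobot/pusht"  # placeholder

try:
    version = get_hub_safe_version(repo_id, "v2.0")
except BackwardCompatibilityError as e:
    # The message carries the full conversion instructions shown above
    print(e)
    raise
```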
@@ -207,18 +230,16 @@ def check_version_compatibility(
         )
 
 
-def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
-    num_version = float(version.strip("v"))
-    if num_version < 2 and enforce_v2:
-        raise ValueError(
-            f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
-            format with v2.0 that is not backward compatible. Please use our conversion script
-            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
-        )
+def get_hub_safe_version(repo_id: str, version: str) -> str:
     api = HfApi()
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
     branches = [b.name for b in dataset_info.branches]
     if version not in branches:
+        num_version = float(version.strip("v"))
+        hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
+        if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
+            raise BackwardCompatibilityError(repo_id, version)
+
         warnings.warn(
             f"""You are trying to load a dataset from {repo_id} created with a previous version of the
             codebase. The following versions are available: {branches}.
@@ -461,6 +482,7 @@ def create_lerobot_dataset_card(
         }
     ]
     card.data.task_categories = ["robotics"]
-    card.data.license = license
     card.data.tags = ["LeRobot"]
+    if license:
+        card.data.license = license
@@ -441,7 +441,7 @@ def convert_dataset(
     arxiv: str | None = None,
     test_branch: str | None = None,
 ):
-    v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
+    v1 = get_hub_safe_version(repo_id, V16)
     v1x_dir = local_dir / V16 / repo_id
     v20_dir = local_dir / V20 / repo_id
     v1x_dir.mkdir(parents=True, exist_ok=True)
@@ -17,6 +17,7 @@ from termcolor import colored
 
 from lerobot.common.datasets.image_writer import safe_stop_image_writer
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import get_features_from_robot
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
@@ -330,3 +331,21 @@ def sanity_check_dataset_name(repo_id, policy):
         raise ValueError(
             f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
         )
+
+
+def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
+    fields = [
+        ("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
+        ("fps", dataset.meta.info["fps"], fps),
+        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
+    ]
+
+    mismatches = []
+    for field, dataset_value, present_value in fields:
+        if dataset_value != present_value:
+            mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
+
+    if mismatches:
+        raise ValueError(
+            "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
+        )
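On failure, the check reports every mismatching field at once, e.g. (hypothetical values):

```python
# Resuming with fps=30 onto a dataset recorded at 25 fps would raise:
# ValueError: Dataset metadata compatibility check failed with mismatches:
# fps: expected 30, got 25
```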
@@ -115,6 +115,7 @@ from lerobot.common.robot_devices.control_utils import (
     record_episode,
     reset_environment,
     sanity_check_dataset_name,
+    sanity_check_dataset_robot_compatibility,
     stop_recording,
     warmup_record,
 )
@@ -207,6 +208,9 @@ def record(
     num_image_writer_threads_per_camera: int = 4,
     display_cameras: bool = True,
     play_sounds: bool = True,
+    resume: bool = False,
+    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
+    local_files_only: bool = False,
 ) -> LeRobotDataset:
     # TODO(rcadene): Add option to record logs
     listener = None
@@ -232,17 +236,29 @@ def record(
             f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
         )
 
-    # Create empty dataset or load existing saved episodes
-    sanity_check_dataset_name(repo_id, policy)
-    dataset = LeRobotDataset.create(
-        repo_id,
-        fps,
-        root=root,
-        robot=robot,
-        use_videos=video,
-        image_writer_processes=num_image_writer_processes,
-        image_writer_threads=num_image_writer_threads_per_camera,
-    )
+    if resume:
+        dataset = LeRobotDataset(
+            repo_id,
+            root=root,
+            local_files_only=local_files_only,
+        )
+        dataset.start_image_writer(
+            num_processes=num_image_writer_processes,
+            num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
+        )
+        sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
+    else:
+        # Create empty dataset or load existing saved episodes
+        sanity_check_dataset_name(repo_id, policy)
+        dataset = LeRobotDataset.create(
+            repo_id,
+            fps,
+            root=root,
+            robot=robot,
+            use_videos=video,
+            image_writer_processes=num_image_writer_processes,
+            image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
+        )
 
     if not robot.is_connected:
         robot.connect()
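In practice, resuming a recording session reduces to passing the new flags (a sketch using the parameters visible in this hunk; paths, ids, and counts are placeholders, and `robot` is a connected `Robot` instance):

```python
# Hypothetical invocation: append new episodes to an existing local dataset
dataset = record(
    robot,
    root="data/my_user/pusht_demos",  # placeholder local path
    repo_id="my_user/pusht_demos",    # placeholder repo id
    single_task="Push the block.",
    fps=30,
    num_episodes=5,
    resume=True,            # load the existing dataset instead of creating one
    local_files_only=True,  # resolve the dataset locally, not on the hub
    push_to_hub=False,
)
```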
@@ -270,8 +286,7 @@ def record(
         # if multi_task:
         #     task = input("Enter your task description: ")
 
-        episode_index = dataset.episode_buffer["episode_index"]
-        log_say(f"Recording episode {episode_index}", play_sounds)
+        log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
         record_episode(
             dataset=dataset,
             robot=robot,
@@ -289,7 +304,7 @@ def record(
         # TODO(rcadene): add an option to enable teleoperation during reset
         # Skip reset for the last episode to be recorded
         if not events["stop_recording"] and (
-            (episode_index < num_episodes - 1) or events["rerecord_episode"]
+            (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
         ):
             log_say("Reset the environment", play_sounds)
             reset_environment(robot, events, reset_time_s)
@@ -117,10 +117,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
 
 
 def push_dataset_card_to_hub(
-    repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
+    repo_id: str,
+    revision: str | None,
+    tags: list | None = None,
+    text: str | None = None,
+    license: str = "apache-2.0",
 ):
     """Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
-    card = create_lerobot_dataset_card(tags=tags, text=text)
+    card = create_lerobot_dataset_card(tags=tags, text=text, license=license)
     card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
@@ -325,7 +325,7 @@ def lerobot_dataset_metadata_factory(
             "lerobot.common.datasets.lerobot_dataset.snapshot_download"
         ) as mock_snapshot_download_patch,
     ):
-        mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
+        mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
         mock_snapshot_download_patch.side_effect = mock_snapshot_download
 
         return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
@@ -275,22 +275,25 @@ def test_resume_record(tmpdir, request, robot_type, mock):
     root = Path(tmpdir) / "data" / repo_id
     single_task = "Do something."
 
-    dataset = record(
-        robot,
-        root,
-        repo_id,
-        single_task,
-        fps=1,
-        warmup_time_s=0,
-        episode_time_s=1,
-        num_episodes=1,
-        push_to_hub=False,
-        video=False,
-        display_cameras=False,
-        play_sounds=False,
-        run_compute_stats=False,
-    )
-    assert len(dataset) == 1, "`dataset` should contain only 1 frame"
+    record_kwargs = {
+        "robot": robot,
+        "root": root,
+        "repo_id": repo_id,
+        "single_task": single_task,
+        "fps": 1,
+        "warmup_time_s": 0,
+        "episode_time_s": 1,
+        "push_to_hub": False,
+        "video": False,
+        "display_cameras": False,
+        "play_sounds": False,
+        "run_compute_stats": False,
+        "local_files_only": True,
+        "num_episodes": 1,
+    }
+
+    dataset = record(**record_kwargs)
+    assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
 
     # init_dataset_return_value = {}
@@ -300,22 +303,13 @@ def test_resume_record(tmpdir, request, robot_type, mock):
     #     return init_dataset_return_value
 
     # with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
-    dataset = record(
-        robot,
-        root,
-        repo_id,
-        single_task,
-        fps=1,
-        warmup_time_s=0,
-        episode_time_s=1,
-        num_episodes=2,
-        push_to_hub=False,
-        video=False,
-        display_cameras=False,
-        play_sounds=False,
-        run_compute_stats=False,
-    )
-    assert len(dataset) == 2, "`dataset` should contain only 1 frame"
+    with pytest.raises(FileExistsError):
+        # Dataset already exists, but resume=False by default
+        record(**record_kwargs)
+
+    dataset = record(**record_kwargs, resume=True)
+    assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
     # assert (
     #     init_dataset_return_value["num_episodes"] == 2
     # ), "`init_dataset` should load the previous episode"