Almost done

This commit is contained in:
Remi Cadene 2024-07-10 00:07:40 +02:00
parent 798373e7bf
commit 52e760a88e
4 changed files with 58 additions and 55 deletions

View File

@ -34,8 +34,8 @@ def make_robot(name):
), ),
}, },
cameras={ cameras={
"macbookpro": OpenCVCamera(1, fps=30, width=640, height=480), "laptop": OpenCVCamera(1, fps=30, width=640, height=480),
"iphone": OpenCVCamera(2, fps=30, width=640, height=480), "phone": OpenCVCamera(2, fps=30, width=640, height=480),
}, },
) )
else: else:

View File

@ -209,7 +209,8 @@ def record_dataset(
# Save images using threads to reach high fps (30 and more) # Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised. # Using `with` to exist smoothly if an execption is raised.
# Using only 4 worker threads to avoid blocking the main thread. # Using only 4 worker threads to avoid blocking the main thread.
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = []
# Execute a few seconds without recording data, to give times # Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing. # to the robot devices to connect and start synchronizing.
@ -236,6 +237,7 @@ def record_dataset(
# Start recording all episodes # Start recording all episodes
ep_dicts = [] ep_dicts = []
for episode_index in range(num_episodes): for episode_index in range(num_episodes):
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
ep_dict = {} ep_dict = {}
frame_index = 0 frame_index = 0
timestamp = 0 timestamp = 0
@ -254,7 +256,8 @@ def record_dataset(
not_image_keys = [key for key in observation if "image" not in key] not_image_keys = [key for key in observation if "image" not in key]
for key in image_keys: for key in image_keys:
executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir) future = executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir)
futures.append(future)
for key in not_image_keys: for key in not_image_keys:
if key not in ep_dict: if key not in ep_dict:
@ -338,6 +341,7 @@ def record_dataset(
videos_dir=videos_dir, videos_dir=videos_dir,
) )
stats = compute_stats(lerobot_dataset) if run_compute_stats else {} stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
lerobot_dataset.stats = stats
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train")) hf_dataset.save_to_disk(str(local_dir / "train"))

7
poetry.lock generated
View File

@ -2375,9 +2375,8 @@ description = "Nvidia JIT LTO Library"
optional = false optional = false
python-versions = ">=3" python-versions = ">=3"
files = [ files = [
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"},
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"},
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
] ]
[[package]] [[package]]
@ -4364,4 +4363,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "81dc830d3d36c67e2fe2aea6cc30829eb2977edbf49a037df21a5f329a01aee5" content-hash = "223a6496a630da8181f119634f96bed3e0de3aaca714f1f1abd7edd562e3f1c6"

View File

@ -19,14 +19,14 @@ def test_teleoperate():
def test_record_dataset_and_replay_episode_and_run_policy(tmpdir): def test_record_dataset_and_replay_episode_and_run_policy(tmpdir):
robot_name = "koch" robot_name = "koch"
env_name = "koch_real" env_name = "koch_real"
policy_name = "act_real" policy_name = "act_koch_real"
#root = Path(tmpdir) #root = Path(tmpdir)
root = Path("tmp/data") root = Path("tmp/data")
repo_id = "lerobot/debug" repo_id = "lerobot/debug"
robot = make_robot(robot_name) robot = make_robot(robot_name)
dataset = record_dataset(robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=2, episode_time_s=2, num_episodes=2) dataset = record_dataset(robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2)
replay_episode(robot, episode=0, fps=30, root=root, repo_id=repo_id) replay_episode(robot, episode=0, fps=30, root=root, repo_id=repo_id)