fixed replay function

This commit is contained in:
Michel Aractingi 2024-10-18 10:25:09 +02:00
parent 498d9ef35c
commit 22df0b381d
1 changed files with 32 additions and 26 deletions

View File

@ -44,7 +44,7 @@ python lerobot/scripts/control_robot.py replay \
--fps 30 \ --fps 30 \
--root tmp/data \ --root tmp/data \
--repo-id $USER/koch_test \ --repo-id $USER/koch_test \
--episode 0 --episodes 0
``` ```
- Record a full dataset in order to train a policy, - Record a full dataset in order to train a policy,
@ -115,8 +115,6 @@ from lerobot.scripts.push_dataset_to_hub import (
######################################################################################## ########################################################################################
# Utilities # Utilities
######################################################################################## ########################################################################################
def say(text, blocking=False): def say(text, blocking=False):
# Check if mac, linux, or windows. # Check if mac, linux, or windows.
if platform.system() == "Darwin": if platform.system() == "Darwin":
@ -571,14 +569,17 @@ def record(
return lerobot_dataset return lerobot_dataset
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"): def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="lerobot/debug"):
# TODO(rcadene): Add option to record logs
env = env()
local_dir = Path(root) / repo_id local_dir = Path(root) / repo_id
if not local_dir.exists(): if not local_dir.exists():
raise ValueError(local_dir) raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root) dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action") items = dataset.hf_dataset.select_columns("action")
for episode in episodes:
env.reset()
from_idx = dataset.episode_data_index["from"][episode].item() from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item() to_idx = dataset.episode_data_index["to"][episode].item()
@ -588,11 +589,15 @@ def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
action = items[idx]["action"] action = items[idx]["action"]
robot.send_action(action)
env.step(action.unsqueeze(0).numpy())
dt_s = time.perf_counter() - start_episode_t dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / fps - dt_s) busy_wait(1 / fps - dt_s)
# wait before playing next episode
busy_wait(5)
if __name__ == "__main__": if __name__ == "__main__":
@ -613,7 +618,6 @@ if __name__ == "__main__":
help="Path to a yaml config you want to use for initializing a sim environment based on gym ", help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
) )
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser]) parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument( parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
@ -705,7 +709,7 @@ if __name__ == "__main__":
default="lerobot/test", default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
) )
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.") parser_replay.add_argument("--episodes", nargs='+', type=int, default=[0], help="Indices of the episodes to replay.")
args = parser.parse_args() args = parser.parse_args()
@ -723,6 +727,8 @@ if __name__ == "__main__":
env_cfg = init_hydra_config(env_config_path) env_cfg = init_hydra_config(env_config_path)
env_fn = lambda: make_env(env_cfg, n_envs=1) env_fn = lambda: make_env(env_cfg, n_envs=1)
robot = None
if control_mode != 'replay':
# make robot # make robot
robot_overrides = ['~cameras', '~follower_arms'] robot_overrides = ['~cameras', '~follower_arms']
robot_cfg = init_hydra_config(robot_path, robot_overrides) robot_cfg = init_hydra_config(robot_path, robot_overrides)
@ -737,12 +743,12 @@ if __name__ == "__main__":
record(env_fn, robot, **kwargs) record(env_fn, robot, **kwargs)
elif control_mode == "replay": elif control_mode == "replay":
replay(env_fn, robot, **kwargs) replay(env_fn, **kwargs)
else: else:
raise ValueError(f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay." ) raise ValueError(f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay." )
if robot.is_connected: if robot and robot.is_connected:
# Disconnect manually to avoid a "Core dump" during process # Disconnect manually to avoid a "Core dump" during process
# termination due to camera threads not properly exiting. # termination due to camera threads not properly exiting.
robot.disconnect() robot.disconnect()