fixed replay function
This commit is contained in:
parent
498d9ef35c
commit
22df0b381d
|
@ -44,7 +44,7 @@ python lerobot/scripts/control_robot.py replay \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root tmp/data \
|
--root tmp/data \
|
||||||
--repo-id $USER/koch_test \
|
--repo-id $USER/koch_test \
|
||||||
--episode 0
|
--episodes 0
|
||||||
```
|
```
|
||||||
|
|
||||||
- Record a full dataset in order to train a policy,
|
- Record a full dataset in order to train a policy,
|
||||||
|
@ -115,8 +115,6 @@ from lerobot.scripts.push_dataset_to_hub import (
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
def say(text, blocking=False):
|
def say(text, blocking=False):
|
||||||
# Check if mac, linux, or windows.
|
# Check if mac, linux, or windows.
|
||||||
if platform.system() == "Darwin":
|
if platform.system() == "Darwin":
|
||||||
|
@ -571,14 +569,17 @@ def record(
|
||||||
return lerobot_dataset
|
return lerobot_dataset
|
||||||
|
|
||||||
|
|
||||||
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||||
# TODO(rcadene): Add option to record logs
|
|
||||||
|
env = env()
|
||||||
local_dir = Path(root) / repo_id
|
local_dir = Path(root) / repo_id
|
||||||
if not local_dir.exists():
|
if not local_dir.exists():
|
||||||
raise ValueError(local_dir)
|
raise ValueError(local_dir)
|
||||||
|
|
||||||
dataset = LeRobotDataset(repo_id, root=root)
|
dataset = LeRobotDataset(repo_id, root=root)
|
||||||
items = dataset.hf_dataset.select_columns("action")
|
items = dataset.hf_dataset.select_columns("action")
|
||||||
|
for episode in episodes:
|
||||||
|
env.reset()
|
||||||
from_idx = dataset.episode_data_index["from"][episode].item()
|
from_idx = dataset.episode_data_index["from"][episode].item()
|
||||||
to_idx = dataset.episode_data_index["to"][episode].item()
|
to_idx = dataset.episode_data_index["to"][episode].item()
|
||||||
|
|
||||||
|
@ -588,11 +589,15 @@ def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
|
|
||||||
action = items[idx]["action"]
|
action = items[idx]["action"]
|
||||||
robot.send_action(action)
|
|
||||||
|
env.step(action.unsqueeze(0).numpy())
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_episode_t
|
dt_s = time.perf_counter() - start_episode_t
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
|
# wait before playing next episode
|
||||||
|
busy_wait(5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -613,7 +618,6 @@ if __name__ == "__main__":
|
||||||
help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
|
help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||||
parser_teleop.add_argument(
|
parser_teleop.add_argument(
|
||||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||||
|
@ -705,7 +709,7 @@ if __name__ == "__main__":
|
||||||
default="lerobot/test",
|
default="lerobot/test",
|
||||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||||
)
|
)
|
||||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
parser_replay.add_argument("--episodes", nargs='+', type=int, default=[0], help="Indices of the episodes to replay.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -723,6 +727,8 @@ if __name__ == "__main__":
|
||||||
env_cfg = init_hydra_config(env_config_path)
|
env_cfg = init_hydra_config(env_config_path)
|
||||||
env_fn = lambda: make_env(env_cfg, n_envs=1)
|
env_fn = lambda: make_env(env_cfg, n_envs=1)
|
||||||
|
|
||||||
|
robot = None
|
||||||
|
if control_mode != 'replay':
|
||||||
# make robot
|
# make robot
|
||||||
robot_overrides = ['~cameras', '~follower_arms']
|
robot_overrides = ['~cameras', '~follower_arms']
|
||||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||||
|
@ -737,12 +743,12 @@ if __name__ == "__main__":
|
||||||
record(env_fn, robot, **kwargs)
|
record(env_fn, robot, **kwargs)
|
||||||
|
|
||||||
elif control_mode == "replay":
|
elif control_mode == "replay":
|
||||||
replay(env_fn, robot, **kwargs)
|
replay(env_fn, **kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay." )
|
raise ValueError(f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay." )
|
||||||
|
|
||||||
if robot.is_connected:
|
if robot and robot.is_connected:
|
||||||
# Disconnect manually to avoid a "Core dump" during process
|
# Disconnect manually to avoid a "Core dump" during process
|
||||||
# termination due to camera threads not properly exiting.
|
# termination due to camera threads not properly exiting.
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
Loading…
Reference in New Issue