Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Remi 2024-05-04 16:01:16 +02:00 committed by GitHub
parent 9451a10c00
commit 341effd58a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 14 deletions

View File

@ -43,17 +43,17 @@ print(f"average number of frames per episode: {dataset.num_samples / dataset.num
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# access frame indices associated to first episode
# Access frame indexes associated to first episode
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter,
# like iterating through the dataset. Here we grab all the image frames.
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
# Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
# Video frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention.
# To view them, we convert to uint8 range [0,255]
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
# To visualize them, we convert to uint8 range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
@ -62,7 +62,7 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions.
# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
# Our datasets can load previous and future frames for each key/modality,
# using timestamps differences with the current loaded frame. For instance:
delta_timestamps = {

View File

@ -32,15 +32,16 @@ local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \
--web-port 9090 \
--ws-port 9087
local$ rerun ws://localhost:9090
local$ rerun ws://localhost:9087
```
"""
@ -109,14 +110,14 @@ def visualize_dataset(
logging.info("Starting Rerun")
if mode == "local":
spawn_local_viewer = not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
elif mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
else:
if mode not in ["local", "distant"]:
raise ValueError(mode)
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun")
if num_workers > 0: