Apply suggestions from code review
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
parent
9451a10c00
commit
341effd58a
|
@ -43,17 +43,17 @@ print(f"average number of frames per episode: {dataset.num_samples / dataset.num
|
|||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# access frame indices associated to first episode
|
||||
# Access frame indexes associated to first episode
|
||||
episode_index = 0
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter,
|
||||
# like iterating through the dataset. Here we grab all the image frames.
|
||||
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
|
||||
# Here we grab all the image frames.
|
||||
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# Video frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention.
|
||||
# To view them, we convert to uint8 range [0,255]
|
||||
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
|
||||
# To visualize them, we convert to uint8 range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c).
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
@ -62,7 +62,7 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
|||
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
|
||||
|
||||
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions.
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
|
||||
# Our datasets can load previous and future frames for each key/modality,
|
||||
# using timestamps differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
|
|
|
@ -32,15 +32,16 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
|||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--web-port 9090 \
|
||||
--ws-port 9087
|
||||
|
||||
local$ rerun ws://localhost:9090
|
||||
local$ rerun ws://localhost:9087
|
||||
```
|
||||
|
||||
"""
|
||||
|
@ -109,14 +110,14 @@ def visualize_dataset(
|
|||
|
||||
logging.info("Starting Rerun")
|
||||
|
||||
if mode == "local":
|
||||
spawn_local_viewer = not save
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
||||
elif mode == "distant":
|
||||
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
|
||||
else:
|
||||
if mode not in ["local", "distant"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
spawn_local_viewer = mode == "local" and not save
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
||||
if mode == "distant":
|
||||
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
if num_workers > 0:
|
||||
|
|
Loading…
Reference in New Issue